http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/main/java/org/apache/zeppelin/spark/ZeppelinRContext.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/ZeppelinRContext.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/ZeppelinRContext.java new file mode 100644 index 0000000..80ea03b --- /dev/null +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/ZeppelinRContext.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark; + +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +/** + * Contains the Spark and Zeppelin Contexts made available to SparkR. + */ +public class ZeppelinRContext { + private static SparkContext sparkContext; + private static SQLContext sqlContext; + private static SparkZeppelinContext zeppelinContext; + private static Object sparkSession; + private static JavaSparkContext javaSparkContext; + + public static void setSparkContext(SparkContext sparkContext) { + ZeppelinRContext.sparkContext = sparkContext; + } + + public static void setZeppelinContext(SparkZeppelinContext zeppelinContext) { + ZeppelinRContext.zeppelinContext = zeppelinContext; + } + + public static void setSqlContext(SQLContext sqlContext) { + ZeppelinRContext.sqlContext = sqlContext; + } + + public static void setSparkSession(Object sparkSession) { + ZeppelinRContext.sparkSession = sparkSession; + } + + public static SparkContext getSparkContext() { + return sparkContext; + } + + public static SQLContext getSqlContext() { + return sqlContext; + } + + public static SparkZeppelinContext getZeppelinContext() { + return zeppelinContext; + } + + public static Object getSparkSession() { + return sparkSession; + } + + public static void setJavaSparkContext(JavaSparkContext jsc) { javaSparkContext = jsc; } + + public static JavaSparkContext getJavaSparkContext() { return javaSparkContext; } +}
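ZeppelinRContext is a plain static holder, so the Scala side of the interpreter has to publish its contexts before the R process attaches over the SparkR backend; the R script shown further down retrieves each of them with SparkR:::callJStatic. Below is a minimal, illustrative sketch of that wiring (the standalone object, the locally created contexts, and the omission of setZeppelinContext are assumptions for the example, not the interpreter's actual open() path):

  import org.apache.spark.{SparkConf, SparkContext}
  import org.apache.spark.api.java.JavaSparkContext
  import org.apache.spark.sql.SQLContext
  import org.apache.zeppelin.spark.ZeppelinRContext

  object RContextWiringSketch {
    def main(args: Array[String]): Unit = {
      val sc = new SparkContext(new SparkConf().setAppName("Zeppelin").setMaster("local[*]"))
      val sqlContext = new SQLContext(sc)

      // Publish the contexts through the static holder before launching the R worker.
      ZeppelinRContext.setSparkContext(sc)
      ZeppelinRContext.setJavaSparkContext(new JavaSparkContext(sc))
      ZeppelinRContext.setSqlContext(sqlContext)
      // setSparkSession takes Object so the holder also compiles against Spark 1.x;
      // on Spark 2.x the SparkSession instance would be passed here.
      // setZeppelinContext would receive the interpreter's SparkZeppelinContext (omitted).

      sc.stop()
    }
  }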
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/main/java/org/apache/zeppelin/spark/dep/SparkDependencyContext.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/dep/SparkDependencyContext.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/dep/SparkDependencyContext.java new file mode 100644 index 0000000..0235fc6 --- /dev/null +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/dep/SparkDependencyContext.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark.dep; + +import java.io.File; +import java.net.MalformedURLException; +import java.util.LinkedList; +import java.util.List; + +import org.apache.zeppelin.dep.Booter; +import org.apache.zeppelin.dep.Dependency; +import org.apache.zeppelin.dep.Repository; + +import org.sonatype.aether.RepositorySystem; +import org.sonatype.aether.RepositorySystemSession; +import org.sonatype.aether.artifact.Artifact; +import org.sonatype.aether.collection.CollectRequest; +import org.sonatype.aether.graph.DependencyFilter; +import org.sonatype.aether.repository.RemoteRepository; +import org.sonatype.aether.repository.Authentication; +import org.sonatype.aether.resolution.ArtifactResolutionException; +import org.sonatype.aether.resolution.ArtifactResult; +import org.sonatype.aether.resolution.DependencyRequest; +import org.sonatype.aether.resolution.DependencyResolutionException; +import org.sonatype.aether.util.artifact.DefaultArtifact; +import org.sonatype.aether.util.artifact.JavaScopes; +import org.sonatype.aether.util.filter.DependencyFilterUtils; +import org.sonatype.aether.util.filter.PatternExclusionsDependencyFilter; + + +/** + * + */ +public class SparkDependencyContext { + List<Dependency> dependencies = new LinkedList<>(); + List<Repository> repositories = new LinkedList<>(); + + List<File> files = new LinkedList<>(); + List<File> filesDist = new LinkedList<>(); + private RepositorySystem system = Booter.newRepositorySystem(); + private RepositorySystemSession session; + private RemoteRepository mavenCentral = Booter.newCentralRepository(); + private RemoteRepository mavenLocal = Booter.newLocalRepository(); + private List<RemoteRepository> additionalRepos = new LinkedList<>(); + + public SparkDependencyContext(String localRepoPath, String additionalRemoteRepository) { + session = Booter.newRepositorySystemSession(system, localRepoPath); + addRepoFromProperty(additionalRemoteRepository); + } + + public Dependency load(String lib) { + Dependency dep = new Dependency(lib); + + if (dependencies.contains(dep)) { + dependencies.remove(dep); + } + dependencies.add(dep); + return dep; + } + + 
public Repository addRepo(String name) { + Repository rep = new Repository(name); + repositories.add(rep); + return rep; + } + + public void reset() { + dependencies = new LinkedList<>(); + repositories = new LinkedList<>(); + + files = new LinkedList<>(); + filesDist = new LinkedList<>(); + } + + private void addRepoFromProperty(String listOfRepo) { + if (listOfRepo != null) { + String[] repos = listOfRepo.split(";"); + for (String repo : repos) { + String[] parts = repo.split(","); + if (parts.length == 3) { + String id = parts[0].trim(); + String url = parts[1].trim(); + boolean isSnapshot = Boolean.parseBoolean(parts[2].trim()); + if (id.length() > 1 && url.length() > 1) { + RemoteRepository rr = new RemoteRepository(id, "default", url); + rr.setPolicy(isSnapshot, null); + additionalRepos.add(rr); + } + } + } + } + } + + /** + * fetch all artifacts + * @return + * @throws MalformedURLException + * @throws ArtifactResolutionException + * @throws DependencyResolutionException + */ + public List<File> fetch() throws MalformedURLException, + DependencyResolutionException, ArtifactResolutionException { + + for (Dependency dep : dependencies) { + if (!dep.isLocalFsArtifact()) { + List<ArtifactResult> artifacts = fetchArtifactWithDep(dep); + for (ArtifactResult artifact : artifacts) { + if (dep.isDist()) { + filesDist.add(artifact.getArtifact().getFile()); + } + files.add(artifact.getArtifact().getFile()); + } + } else { + if (dep.isDist()) { + filesDist.add(new File(dep.getGroupArtifactVersion())); + } + files.add(new File(dep.getGroupArtifactVersion())); + } + } + + return files; + } + + private List<ArtifactResult> fetchArtifactWithDep(Dependency dep) + throws DependencyResolutionException, ArtifactResolutionException { + Artifact artifact = new DefaultArtifact( + SparkDependencyResolver.inferScalaVersion(dep.getGroupArtifactVersion())); + + DependencyFilter classpathFlter = DependencyFilterUtils + .classpathFilter(JavaScopes.COMPILE); + PatternExclusionsDependencyFilter exclusionFilter = new PatternExclusionsDependencyFilter( + SparkDependencyResolver.inferScalaVersion(dep.getExclusions())); + + CollectRequest collectRequest = new CollectRequest(); + collectRequest.setRoot(new org.sonatype.aether.graph.Dependency(artifact, + JavaScopes.COMPILE)); + + collectRequest.addRepository(mavenCentral); + collectRequest.addRepository(mavenLocal); + for (RemoteRepository repo : additionalRepos) { + collectRequest.addRepository(repo); + } + for (Repository repo : repositories) { + RemoteRepository rr = new RemoteRepository(repo.getId(), "default", repo.getUrl()); + rr.setPolicy(repo.isSnapshot(), null); + Authentication auth = repo.getAuthentication(); + if (auth != null) { + rr.setAuthentication(auth); + } + collectRequest.addRepository(rr); + } + + DependencyRequest dependencyRequest = new DependencyRequest(collectRequest, + DependencyFilterUtils.andFilter(exclusionFilter, classpathFlter)); + + return system.resolveDependencies(session, dependencyRequest).getArtifactResults(); + } + + public List<File> getFiles() { + return files; + } + + public List<File> getFilesDist() { + return filesDist; + } +} http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/main/java/org/apache/zeppelin/spark/dep/SparkDependencyResolver.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/dep/SparkDependencyResolver.java 
b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/dep/SparkDependencyResolver.java new file mode 100644 index 0000000..46224a8 --- /dev/null +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/dep/SparkDependencyResolver.java @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark.dep; + +import java.io.File; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.net.URL; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import org.apache.commons.lang.StringUtils; +import org.apache.spark.SparkContext; +import org.apache.zeppelin.dep.AbstractDependencyResolver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.sonatype.aether.artifact.Artifact; +import org.sonatype.aether.collection.CollectRequest; +import org.sonatype.aether.graph.Dependency; +import org.sonatype.aether.graph.DependencyFilter; +import org.sonatype.aether.repository.RemoteRepository; +import org.sonatype.aether.resolution.ArtifactResult; +import org.sonatype.aether.resolution.DependencyRequest; +import org.sonatype.aether.util.artifact.DefaultArtifact; +import org.sonatype.aether.util.artifact.JavaScopes; +import org.sonatype.aether.util.filter.DependencyFilterUtils; +import org.sonatype.aether.util.filter.PatternExclusionsDependencyFilter; + +import scala.Some; +import scala.collection.IndexedSeq; +import scala.reflect.io.AbstractFile; +import scala.tools.nsc.Global; +import scala.tools.nsc.backend.JavaPlatform; +import scala.tools.nsc.util.ClassPath; +import scala.tools.nsc.util.MergedClassPath; + +/** + * Deps resolver. + * Add new dependencies from mvn repo (at runtime) to Spark interpreter group. 
+ */ +public class SparkDependencyResolver extends AbstractDependencyResolver { + Logger logger = LoggerFactory.getLogger(SparkDependencyResolver.class); + private Global global; + private ClassLoader runtimeClassLoader; + private SparkContext sc; + + private final String[] exclusions = new String[] {"org.scala-lang:scala-library", + "org.scala-lang:scala-compiler", + "org.scala-lang:scala-reflect", + "org.scala-lang:scalap", + "org.apache.zeppelin:zeppelin-zengine", + "org.apache.zeppelin:zeppelin-spark", + "org.apache.zeppelin:zeppelin-server"}; + + public SparkDependencyResolver(Global global, + ClassLoader runtimeClassLoader, + SparkContext sc, + String localRepoPath, + String additionalRemoteRepository) { + super(localRepoPath); + this.global = global; + this.runtimeClassLoader = runtimeClassLoader; + this.sc = sc; + addRepoFromProperty(additionalRemoteRepository); + } + + private void addRepoFromProperty(String listOfRepo) { + if (listOfRepo != null) { + String[] repos = listOfRepo.split(";"); + for (String repo : repos) { + String[] parts = repo.split(","); + if (parts.length == 3) { + String id = parts[0].trim(); + String url = parts[1].trim(); + boolean isSnapshot = Boolean.parseBoolean(parts[2].trim()); + if (id.length() > 1 && url.length() > 1) { + addRepo(id, url, isSnapshot); + } + } + } + } + } + + private void updateCompilerClassPath(URL[] urls) throws IllegalAccessException, + IllegalArgumentException, InvocationTargetException { + + JavaPlatform platform = (JavaPlatform) global.platform(); + MergedClassPath<AbstractFile> newClassPath = mergeUrlsIntoClassPath(platform, urls); + + Method[] methods = platform.getClass().getMethods(); + for (Method m : methods) { + if (m.getName().endsWith("currentClassPath_$eq")) { + m.invoke(platform, new Some(newClassPath)); + break; + } + } + + // NOTE: Must use reflection until this is exposed/fixed upstream in Scala + List<String> classPaths = new LinkedList<>(); + for (URL url : urls) { + classPaths.add(url.getPath()); + } + + // Reload all jars specified into our compiler + global.invalidateClassPathEntries(scala.collection.JavaConversions.asScalaBuffer(classPaths) + .toList()); + } + + // Until spark 1.1.x + // check https://github.com/apache/spark/commit/191d7cf2a655d032f160b9fa181730364681d0e7 + private void updateRuntimeClassPath_1_x(URL[] urls) throws SecurityException, + IllegalAccessException, IllegalArgumentException, + InvocationTargetException, NoSuchMethodException { + Method addURL; + addURL = runtimeClassLoader.getClass().getDeclaredMethod("addURL", new Class[] {URL.class}); + addURL.setAccessible(true); + for (URL url : urls) { + addURL.invoke(runtimeClassLoader, url); + } + } + + private void updateRuntimeClassPath_2_x(URL[] urls) throws SecurityException, + IllegalAccessException, IllegalArgumentException, + InvocationTargetException, NoSuchMethodException { + Method addURL; + addURL = runtimeClassLoader.getClass().getDeclaredMethod("addNewUrl", new Class[] {URL.class}); + addURL.setAccessible(true); + for (URL url : urls) { + addURL.invoke(runtimeClassLoader, url); + } + } + + private MergedClassPath<AbstractFile> mergeUrlsIntoClassPath(JavaPlatform platform, URL[] urls) { + IndexedSeq<ClassPath<AbstractFile>> entries = + ((MergedClassPath<AbstractFile>) platform.classPath()).entries(); + List<ClassPath<AbstractFile>> cp = new LinkedList<>(); + + for (int i = 0; i < entries.size(); i++) { + cp.add(entries.apply(i)); + } + + for (URL url : urls) { + AbstractFile file; + if ("file".equals(url.getProtocol())) { + File f = 
new File(url.getPath()); + if (f.isDirectory()) { + file = AbstractFile.getDirectory(scala.reflect.io.File.jfile2path(f)); + } else { + file = AbstractFile.getFile(scala.reflect.io.File.jfile2path(f)); + } + } else { + file = AbstractFile.getURL(url); + } + + ClassPath<AbstractFile> newcp = platform.classPath().context().newClassPath(file); + + // distinct + if (cp.contains(newcp) == false) { + cp.add(newcp); + } + } + + return new MergedClassPath(scala.collection.JavaConversions.asScalaBuffer(cp).toIndexedSeq(), + platform.classPath().context()); + } + + public List<String> load(String artifact, + boolean addSparkContext) throws Exception { + return load(artifact, new LinkedList<String>(), addSparkContext); + } + + public List<String> load(String artifact, Collection<String> excludes, + boolean addSparkContext) throws Exception { + if (StringUtils.isBlank(artifact)) { + // Should throw here + throw new RuntimeException("Invalid artifact to load"); + } + + // <groupId>:<artifactId>[:<extension>[:<classifier>]]:<version> + int numSplits = artifact.split(":").length; + if (numSplits >= 3 && numSplits <= 6) { + return loadFromMvn(artifact, excludes, addSparkContext); + } else { + loadFromFs(artifact, addSparkContext); + LinkedList<String> libs = new LinkedList<>(); + libs.add(artifact); + return libs; + } + } + + private void loadFromFs(String artifact, boolean addSparkContext) throws Exception { + File jarFile = new File(artifact); + + global.new Run(); + + if (sc.version().startsWith("1.1")) { + updateRuntimeClassPath_1_x(new URL[] {jarFile.toURI().toURL()}); + } else { + updateRuntimeClassPath_2_x(new URL[] {jarFile.toURI().toURL()}); + } + + if (addSparkContext) { + sc.addJar(jarFile.getAbsolutePath()); + } + } + + private List<String> loadFromMvn(String artifact, Collection<String> excludes, + boolean addSparkContext) throws Exception { + List<String> loadedLibs = new LinkedList<>(); + Collection<String> allExclusions = new LinkedList<>(); + allExclusions.addAll(excludes); + allExclusions.addAll(Arrays.asList(exclusions)); + + List<ArtifactResult> listOfArtifact; + listOfArtifact = getArtifactsWithDep(artifact, allExclusions); + + Iterator<ArtifactResult> it = listOfArtifact.iterator(); + while (it.hasNext()) { + Artifact a = it.next().getArtifact(); + String gav = a.getGroupId() + ":" + a.getArtifactId() + ":" + a.getVersion(); + for (String exclude : allExclusions) { + if (gav.startsWith(exclude)) { + it.remove(); + break; + } + } + } + + List<URL> newClassPathList = new LinkedList<>(); + List<File> files = new LinkedList<>(); + for (ArtifactResult artifactResult : listOfArtifact) { + logger.info("Load " + artifactResult.getArtifact().getGroupId() + ":" + + artifactResult.getArtifact().getArtifactId() + ":" + + artifactResult.getArtifact().getVersion()); + newClassPathList.add(artifactResult.getArtifact().getFile().toURI().toURL()); + files.add(artifactResult.getArtifact().getFile()); + loadedLibs.add(artifactResult.getArtifact().getGroupId() + ":" + + artifactResult.getArtifact().getArtifactId() + ":" + + artifactResult.getArtifact().getVersion()); + } + + global.new Run(); + if (sc.version().startsWith("1.1")) { + updateRuntimeClassPath_1_x(newClassPathList.toArray(new URL[0])); + } else { + updateRuntimeClassPath_2_x(newClassPathList.toArray(new URL[0])); + } + updateCompilerClassPath(newClassPathList.toArray(new URL[0])); + + if (addSparkContext) { + for (File f : files) { + sc.addJar(f.getAbsolutePath()); + } + } + + return loadedLibs; + } + + /** + * @param dependency + * @param 
excludes list of pattern can either be of the form groupId:artifactId + * @return + * @throws Exception + */ + @Override + public List<ArtifactResult> getArtifactsWithDep(String dependency, + Collection<String> excludes) throws Exception { + Artifact artifact = new DefaultArtifact(inferScalaVersion(dependency)); + DependencyFilter classpathFilter = DependencyFilterUtils.classpathFilter(JavaScopes.COMPILE); + PatternExclusionsDependencyFilter exclusionFilter = + new PatternExclusionsDependencyFilter(inferScalaVersion(excludes)); + + CollectRequest collectRequest = new CollectRequest(); + collectRequest.setRoot(new Dependency(artifact, JavaScopes.COMPILE)); + + synchronized (repos) { + for (RemoteRepository repo : repos) { + collectRequest.addRepository(repo); + } + } + DependencyRequest dependencyRequest = new DependencyRequest(collectRequest, + DependencyFilterUtils.andFilter(exclusionFilter, classpathFilter)); + return system.resolveDependencies(session, dependencyRequest).getArtifactResults(); + } + + public static Collection<String> inferScalaVersion(Collection<String> artifact) { + List<String> list = new LinkedList<>(); + for (String a : artifact) { + list.add(inferScalaVersion(a)); + } + return list; + } + + public static String inferScalaVersion(String artifact) { + int pos = artifact.indexOf(":"); + if (pos < 0 || pos + 2 >= artifact.length()) { + // failed to infer + return artifact; + } + + if (':' == artifact.charAt(pos + 1)) { + String restOfthem = ""; + String versionSep = ":"; + + String groupId = artifact.substring(0, pos); + int nextPos = artifact.indexOf(":", pos + 2); + if (nextPos < 0) { + if (artifact.charAt(artifact.length() - 1) == '*') { + nextPos = artifact.length() - 1; + versionSep = ""; + restOfthem = "*"; + } else { + versionSep = ""; + nextPos = artifact.length(); + } + } + + String artifactId = artifact.substring(pos + 2, nextPos); + if (nextPos < artifact.length()) { + if (!restOfthem.equals("*")) { + restOfthem = artifact.substring(nextPos + 1); + } + } + + String [] version = scala.util.Properties.versionNumberString().split("[.]"); + String scalaVersion = version[0] + "." + version[1]; + + return groupId + ":" + artifactId + "_" + scalaVersion + versionSep + restOfthem; + } else { + return artifact; + } + } +} http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/main/resources/R/zeppelin_sparkr.R ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/resources/R/zeppelin_sparkr.R b/spark/interpreter/src/main/resources/R/zeppelin_sparkr.R new file mode 100644 index 0000000..525c6c5 --- /dev/null +++ b/spark/interpreter/src/main/resources/R/zeppelin_sparkr.R @@ -0,0 +1,105 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +args <- commandArgs(trailingOnly = TRUE) + +hashCode <- as.integer(args[1]) +port <- as.integer(args[2]) +libPath <- args[3] +version <- as.integer(args[4]) +rm(args) + +print(paste("Port ", toString(port))) +print(paste("LibPath ", libPath)) + +.libPaths(c(file.path(libPath), .libPaths())) +library(SparkR) + + +SparkR:::connectBackend("localhost", port, 6000) + +# scStartTime is needed by R/pkg/R/sparkR.R +assign(".scStartTime", as.integer(Sys.time()), envir = SparkR:::.sparkREnv) + +# getZeppelinR +.zeppelinR = SparkR:::callJStatic("org.apache.zeppelin.spark.ZeppelinR", "getZeppelinR", hashCode) + +# setup spark env +assign(".sc", SparkR:::callJStatic("org.apache.zeppelin.spark.ZeppelinRContext", "getSparkContext"), envir = SparkR:::.sparkREnv) +assign("sc", get(".sc", envir = SparkR:::.sparkREnv), envir=.GlobalEnv) +if (version >= 20000) { + assign(".sparkRsession", SparkR:::callJStatic("org.apache.zeppelin.spark.ZeppelinRContext", "getSparkSession"), envir = SparkR:::.sparkREnv) + assign("spark", get(".sparkRsession", envir = SparkR:::.sparkREnv), envir = .GlobalEnv) + assign(".sparkRjsc", SparkR:::callJStatic("org.apache.zeppelin.spark.ZeppelinRContext", "getJavaSparkContext"), envir = SparkR:::.sparkREnv) +} +assign(".sqlc", SparkR:::callJStatic("org.apache.zeppelin.spark.ZeppelinRContext", "getSqlContext"), envir = SparkR:::.sparkREnv) +assign("sqlContext", get(".sqlc", envir = SparkR:::.sparkREnv), envir = .GlobalEnv) +assign(".zeppelinContext", SparkR:::callJStatic("org.apache.zeppelin.spark.ZeppelinRContext", "getZeppelinContext"), envir = .GlobalEnv) + +z.put <- function(name, object) { + SparkR:::callJMethod(.zeppelinContext, "put", name, object) +} +z.get <- function(name) { + SparkR:::callJMethod(.zeppelinContext, "get", name) +} +z.input <- function(name, value) { + SparkR:::callJMethod(.zeppelinContext, "input", name, value) +} + +# notify script is initialized +SparkR:::callJMethod(.zeppelinR, "onScriptInitialized") + +while (TRUE) { + req <- SparkR:::callJMethod(.zeppelinR, "getRequest") + type <- SparkR:::callJMethod(req, "getType") + stmt <- SparkR:::callJMethod(req, "getStmt") + value <- SparkR:::callJMethod(req, "getValue") + + if (type == "eval") { + tryCatch({ + ret <- eval(parse(text=stmt)) + SparkR:::callJMethod(.zeppelinR, "setResponse", "", FALSE) + }, error = function(e) { + SparkR:::callJMethod(.zeppelinR, "setResponse", toString(e), TRUE) + }) + } else if (type == "set") { + tryCatch({ + ret <- assign(stmt, value) + SparkR:::callJMethod(.zeppelinR, "setResponse", "", FALSE) + }, error = function(e) { + SparkR:::callJMethod(.zeppelinR, "setResponse", toString(e), TRUE) + }) + } else if (type == "get") { + tryCatch({ + ret <- eval(parse(text=stmt)) + SparkR:::callJMethod(.zeppelinR, "setResponse", ret, FALSE) + }, error = function(e) { + SparkR:::callJMethod(.zeppelinR, "setResponse", toString(e), TRUE) + }) + } else if (type == "getS") { + tryCatch({ + ret <- eval(parse(text=stmt)) + SparkR:::callJMethod(.zeppelinR, "setResponse", toString(ret), FALSE) + }, error = function(e) { + SparkR:::callJMethod(.zeppelinR, "setResponse", toString(e), TRUE) + }) + } else { + # unsupported type + SparkR:::callJMethod(.zeppelinR, "setResponse", paste("Unsupported type ", type), TRUE) + } +} http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/main/resources/interpreter-setting.json ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/resources/interpreter-setting.json 
b/spark/interpreter/src/main/resources/interpreter-setting.json new file mode 100644 index 0000000..7e647d7 --- /dev/null +++ b/spark/interpreter/src/main/resources/interpreter-setting.json @@ -0,0 +1,233 @@ +[ + { + "group": "spark", + "name": "spark", + "className": "org.apache.zeppelin.spark.SparkInterpreter", + "defaultInterpreter": true, + "properties": { + "spark.executor.memory": { + "envName": null, + "propertyName": "spark.executor.memory", + "defaultValue": "", + "description": "Executor memory per worker instance. ex) 512m, 32g", + "type": "string" + }, + "args": { + "envName": null, + "propertyName": null, + "defaultValue": "", + "description": "spark commandline args", + "type": "textarea" + }, + "zeppelin.spark.useHiveContext": { + "envName": "ZEPPELIN_SPARK_USEHIVECONTEXT", + "propertyName": "zeppelin.spark.useHiveContext", + "defaultValue": true, + "description": "Use HiveContext instead of SQLContext if it is true.", + "type": "checkbox" + }, + "spark.app.name": { + "envName": "SPARK_APP_NAME", + "propertyName": "spark.app.name", + "defaultValue": "Zeppelin", + "description": "The name of spark application.", + "type": "string" + }, + "zeppelin.spark.printREPLOutput": { + "envName": null, + "propertyName": "zeppelin.spark.printREPLOutput", + "defaultValue": true, + "description": "Print REPL output", + "type": "checkbox" + }, + "spark.cores.max": { + "envName": null, + "propertyName": "spark.cores.max", + "defaultValue": "", + "description": "Total number of cores to use. Empty value uses all available core.", + "type": "number" + }, + "zeppelin.spark.maxResult": { + "envName": "ZEPPELIN_SPARK_MAXRESULT", + "propertyName": "zeppelin.spark.maxResult", + "defaultValue": "1000", + "description": "Max number of Spark SQL result to display.", + "type": "number" + }, + "master": { + "envName": "MASTER", + "propertyName": "spark.master", + "defaultValue": "local[*]", + "description": "Spark master uri. 
ex) spark://masterhost:7077", + "type": "string" + }, + "zeppelin.spark.enableSupportedVersionCheck": { + "envName": null, + "propertyName": "zeppelin.spark.enableSupportedVersionCheck", + "defaultValue": true, + "description": "Do not change - developer only setting, not for production use", + "type": "checkbox" + }, + "zeppelin.spark.uiWebUrl": { + "envName": null, + "propertyName": "zeppelin.spark.uiWebUrl", + "defaultValue": "", + "description": "Override Spark UI default URL", + "type": "string" + }, + "zeppelin.spark.useNew": { + "envName": null, + "propertyName": "zeppelin.spark.useNew", + "defaultValue": "false", + "description": "Whether use new spark interpreter implementation", + "type": "checkbox" + } + }, + "editor": { + "language": "scala", + "editOnDblClick": false, + "completionKey": "TAB" + } + }, + { + "group": "spark", + "name": "sql", + "className": "org.apache.zeppelin.spark.SparkSqlInterpreter", + "properties": { + "zeppelin.spark.concurrentSQL": { + "envName": "ZEPPELIN_SPARK_CONCURRENTSQL", + "propertyName": "zeppelin.spark.concurrentSQL", + "defaultValue": false, + "description": "Execute multiple SQL concurrently if set true.", + "type": "checkbox" + }, + "zeppelin.spark.sql.stacktrace": { + "envName": "ZEPPELIN_SPARK_SQL_STACKTRACE", + "propertyName": "zeppelin.spark.sql.stacktrace", + "defaultValue": false, + "description": "Show full exception stacktrace for SQL queries if set to true.", + "type": "checkbox" + }, + "zeppelin.spark.maxResult": { + "envName": "ZEPPELIN_SPARK_MAXRESULT", + "propertyName": "zeppelin.spark.maxResult", + "defaultValue": "1000", + "description": "Max number of Spark SQL result to display.", + "type": "number" + }, + "zeppelin.spark.importImplicit": { + "envName": "ZEPPELIN_SPARK_IMPORTIMPLICIT", + "propertyName": "zeppelin.spark.importImplicit", + "defaultValue": true, + "description": "Import implicits, UDF collection, and sql if set true. 
true by default.", + "type": "checkbox" + } + }, + "editor": { + "language": "sql", + "editOnDblClick": false, + "completionKey": "TAB" + } + }, + { + "group": "spark", + "name": "dep", + "className": "org.apache.zeppelin.spark.DepInterpreter", + "properties": { + "zeppelin.dep.localrepo": { + "envName": "ZEPPELIN_DEP_LOCALREPO", + "propertyName": null, + "defaultValue": "local-repo", + "description": "local repository for dependency loader", + "type": "string" + }, + "zeppelin.dep.additionalRemoteRepository": { + "envName": null, + "propertyName": null, + "defaultValue": "spark-packages,http://dl.bintray.com/spark-packages/maven,false;", + "description": "A list of 'id,remote-repository-URL,is-snapshot;' for each remote repository.", + "type": "textarea" + } + }, + "editor": { + "language": "scala", + "editOnDblClick": false, + "completionKey": "TAB" + } + }, + { + "group": "spark", + "name": "pyspark", + "className": "org.apache.zeppelin.spark.PySparkInterpreter", + "properties": { + "zeppelin.pyspark.python": { + "envName": "PYSPARK_PYTHON", + "propertyName": null, + "defaultValue": "python", + "description": "Python command to run pyspark with", + "type": "string" + }, + "zeppelin.pyspark.useIPython": { + "envName": null, + "propertyName": "zeppelin.pyspark.useIPython", + "defaultValue": true, + "description": "whether use IPython when it is available", + "type": "checkbox" + } + }, + "editor": { + "language": "python", + "editOnDblClick": false, + "completionKey": "TAB" + } + }, + { + "group": "spark", + "name": "ipyspark", + "className": "org.apache.zeppelin.spark.IPySparkInterpreter", + "properties": {}, + "editor": { + "language": "python", + "editOnDblClick": false + } + }, + { + "group": "spark", + "name": "r", + "className": "org.apache.zeppelin.spark.SparkRInterpreter", + "properties": { + "zeppelin.R.knitr": { + "envName": "ZEPPELIN_R_KNITR", + "propertyName": "zeppelin.R.knitr", + "defaultValue": true, + "description": "whether use knitr or not", + "type": "checkbox" + }, + "zeppelin.R.cmd": { + "envName": "ZEPPELIN_R_CMD", + "propertyName": "zeppelin.R.cmd", + "defaultValue": "R", + "description": "R repl path", + "type": "string" + }, + "zeppelin.R.image.width": { + "envName": "ZEPPELIN_R_IMAGE_WIDTH", + "propertyName": "zeppelin.R.image.width", + "defaultValue": "100%", + "description": "", + "type": "number" + }, + "zeppelin.R.render.options": { + "envName": "ZEPPELIN_R_RENDER_OPTIONS", + "propertyName": "zeppelin.R.render.options", + "defaultValue": "out.format = 'html', comment = NA, echo = FALSE, results = 'asis', message = F, warning = F, fig.retina = 2", + "description": "", + "type": "textarea" + } + }, + "editor": { + "language": "r", + "editOnDblClick": false + } + } +] http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/main/resources/python/zeppelin_ipyspark.py ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/resources/python/zeppelin_ipyspark.py b/spark/interpreter/src/main/resources/python/zeppelin_ipyspark.py new file mode 100644 index 0000000..324f481 --- /dev/null +++ b/spark/interpreter/src/main/resources/python/zeppelin_ipyspark.py @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from pyspark.conf import SparkConf +from pyspark.context import SparkContext + +# for back compatibility +from pyspark.sql import SQLContext + +# start JVM gateway +client = GatewayClient(port=${JVM_GATEWAY_PORT}) +gateway = JavaGateway(client, auto_convert=True) + +java_import(gateway.jvm, "org.apache.spark.SparkEnv") +java_import(gateway.jvm, "org.apache.spark.SparkConf") +java_import(gateway.jvm, "org.apache.spark.api.java.*") +java_import(gateway.jvm, "org.apache.spark.api.python.*") +java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") + +intp = gateway.entry_point +jsc = intp.getJavaSparkContext() + +java_import(gateway.jvm, "org.apache.spark.sql.*") +java_import(gateway.jvm, "org.apache.spark.sql.hive.*") +java_import(gateway.jvm, "scala.Tuple2") + +jconf = jsc.getConf() +conf = SparkConf(_jvm=gateway.jvm, _jconf=jconf) +sc = _zsc_ = SparkContext(jsc=jsc, gateway=gateway, conf=conf) + +if intp.isSpark2(): + from pyspark.sql import SparkSession + + spark = __zSpark__ = SparkSession(sc, intp.getSparkSession()) + sqlContext = sqlc = __zSqlc__ = __zSpark__._wrapped +else: + sqlContext = sqlc = __zSqlc__ = SQLContext(sparkContext=sc, sqlContext=intp.getSQLContext()) http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py b/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py new file mode 100644 index 0000000..c10855a --- /dev/null +++ b/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py @@ -0,0 +1,393 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os, sys, getopt, traceback, json, re + +from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from py4j.protocol import Py4JJavaError +from pyspark.conf import SparkConf +from pyspark.context import SparkContext +import ast +import warnings + +# for back compatibility +from pyspark.sql import SQLContext, HiveContext, Row + +class Logger(object): + def __init__(self): + pass + + def write(self, message): + intp.appendOutput(message) + + def reset(self): + pass + + def flush(self): + pass + + +class PyZeppelinContext(dict): + def __init__(self, zc): + self.z = zc + self._displayhook = lambda *args: None + + def show(self, obj): + from pyspark.sql import DataFrame + if isinstance(obj, DataFrame): + print(self.z.showData(obj._jdf)) + else: + print(str(obj)) + + # By implementing special methods it makes operating on it more Pythonic + def __setitem__(self, key, item): + self.z.put(key, item) + + def __getitem__(self, key): + return self.z.get(key) + + def __delitem__(self, key): + self.z.remove(key) + + def __contains__(self, item): + return self.z.containsKey(item) + + def add(self, key, value): + self.__setitem__(key, value) + + def put(self, key, value): + self.__setitem__(key, value) + + def get(self, key): + return self.__getitem__(key) + + def getInterpreterContext(self): + return self.z.getInterpreterContext() + + def input(self, name, defaultValue=""): + return self.z.input(name, defaultValue) + + def textbox(self, name, defaultValue=""): + return self.z.textbox(name, defaultValue) + + def noteTextbox(self, name, defaultValue=""): + return self.z.noteTextbox(name, defaultValue) + + def select(self, name, options, defaultValue=""): + # auto_convert to ArrayList doesn't match the method signature on JVM side + return self.z.select(name, defaultValue, self.getParamOptions(options)) + + def noteSelect(self, name, options, defaultValue=""): + return self.z.noteSelect(name, defaultValue, self.getParamOptions(options)) + + def checkbox(self, name, options, defaultChecked=None): + optionsIterable = self.getParamOptions(options) + defaultCheckedIterables = self.getDefaultChecked(defaultChecked) + checkedItems = gateway.jvm.scala.collection.JavaConversions.seqAsJavaList(self.z.checkbox(name, defaultCheckedIterables, optionsIterable)) + result = [] + for checkedItem in checkedItems: + result.append(checkedItem) + return result; + + def noteCheckbox(self, name, options, defaultChecked=None): + optionsIterable = self.getParamOptions(options) + defaultCheckedIterables = self.getDefaultChecked(defaultChecked) + checkedItems = gateway.jvm.scala.collection.JavaConversions.seqAsJavaList(self.z.noteCheckbox(name, defaultCheckedIterables, optionsIterable)) + result = [] + for checkedItem in checkedItems: + result.append(checkedItem) + return result; + + def getParamOptions(self, options): + tuples = list(map(lambda items: self.__tupleToScalaTuple2(items), options)) + return gateway.jvm.scala.collection.JavaConversions.collectionAsScalaIterable(tuples) + + def getDefaultChecked(self, defaultChecked): + if defaultChecked is None: + defaultChecked = [] + return gateway.jvm.scala.collection.JavaConversions.collectionAsScalaIterable(defaultChecked) + + def registerHook(self, event, cmd, replName=None): + if replName is None: + self.z.registerHook(event, cmd) + else: + self.z.registerHook(event, cmd, replName) + + def unregisterHook(self, event, replName=None): + if replName is None: + self.z.unregisterHook(event) + else: + self.z.unregisterHook(event, replName) + + def 
getHook(self, event, replName=None): + if replName is None: + return self.z.getHook(event) + return self.z.getHook(event, replName) + + def _setup_matplotlib(self): + # If we don't have matplotlib installed don't bother continuing + try: + import matplotlib + except ImportError: + return + + # Make sure custom backends are available in the PYTHONPATH + rootdir = os.environ.get('ZEPPELIN_HOME', os.getcwd()) + mpl_path = os.path.join(rootdir, 'interpreter', 'lib', 'python') + if mpl_path not in sys.path: + sys.path.append(mpl_path) + + # Finally check if backend exists, and if so configure as appropriate + try: + matplotlib.use('module://backend_zinline') + import backend_zinline + + # Everything looks good so make config assuming that we are using + # an inline backend + self._displayhook = backend_zinline.displayhook + self.configure_mpl(width=600, height=400, dpi=72, fontsize=10, + interactive=True, format='png', context=self.z) + except ImportError: + # Fall back to Agg if no custom backend installed + matplotlib.use('Agg') + warnings.warn("Unable to load inline matplotlib backend, " + "falling back to Agg") + + def configure_mpl(self, **kwargs): + import mpl_config + mpl_config.configure(**kwargs) + + def __tupleToScalaTuple2(self, tuple): + if (len(tuple) == 2): + return gateway.jvm.scala.Tuple2(tuple[0], tuple[1]) + else: + raise IndexError("options must be a list of tuple of 2") + + +class SparkVersion(object): + SPARK_1_4_0 = 10400 + SPARK_1_3_0 = 10300 + SPARK_2_0_0 = 20000 + + def __init__(self, versionNumber): + self.version = versionNumber + + def isAutoConvertEnabled(self): + return self.version >= self.SPARK_1_4_0 + + def isImportAllPackageUnderSparkSql(self): + return self.version >= self.SPARK_1_3_0 + + def isSpark2(self): + return self.version >= self.SPARK_2_0_0 + +class PySparkCompletion: + def __init__(self, interpreterObject): + self.interpreterObject = interpreterObject + + def getGlobalCompletion(self): + objectDefList = [] + try: + for completionItem in list(globals().keys()): + objectDefList.append(completionItem) + except: + return None + else: + return objectDefList + + def getMethodCompletion(self, text_value): + execResult = locals() + if text_value == None: + return None + completion_target = text_value + try: + if len(completion_target) <= 0: + return None + if text_value[-1] == ".": + completion_target = text_value[:-1] + exec("{} = dir({})".format("objectDefList", completion_target), globals(), execResult) + except: + return None + else: + return list(execResult['objectDefList']) + + + def getCompletion(self, text_value): + completionList = set() + + globalCompletionList = self.getGlobalCompletion() + if globalCompletionList != None: + for completionItem in list(globalCompletionList): + completionList.add(completionItem) + + if text_value != None: + objectCompletionList = self.getMethodCompletion(text_value) + if objectCompletionList != None: + for completionItem in list(objectCompletionList): + completionList.add(completionItem) + if len(completionList) <= 0: + self.interpreterObject.setStatementsFinished("", False) + else: + result = json.dumps(list(filter(lambda x : not re.match("^__.*", x), list(completionList)))) + self.interpreterObject.setStatementsFinished(result, False) + +client = GatewayClient(port=int(sys.argv[1])) +sparkVersion = SparkVersion(int(sys.argv[2])) +if sparkVersion.isSpark2(): + from pyspark.sql import SparkSession +else: + from pyspark.sql import SchemaRDD + +if sparkVersion.isAutoConvertEnabled(): + gateway = JavaGateway(client, 
auto_convert = True) +else: + gateway = JavaGateway(client) + +java_import(gateway.jvm, "org.apache.spark.SparkEnv") +java_import(gateway.jvm, "org.apache.spark.SparkConf") +java_import(gateway.jvm, "org.apache.spark.api.java.*") +java_import(gateway.jvm, "org.apache.spark.api.python.*") +java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") + +intp = gateway.entry_point +output = Logger() +sys.stdout = output +sys.stderr = output +intp.onPythonScriptInitialized(os.getpid()) + +jsc = intp.getJavaSparkContext() + +if sparkVersion.isImportAllPackageUnderSparkSql(): + java_import(gateway.jvm, "org.apache.spark.sql.*") + java_import(gateway.jvm, "org.apache.spark.sql.hive.*") +else: + java_import(gateway.jvm, "org.apache.spark.sql.SQLContext") + java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext") + java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext") + java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext") + + +java_import(gateway.jvm, "scala.Tuple2") + +_zcUserQueryNameSpace = {} + +jconf = intp.getSparkConf() +conf = SparkConf(_jvm = gateway.jvm, _jconf = jconf) +sc = _zsc_ = SparkContext(jsc=jsc, gateway=gateway, conf=conf) +_zcUserQueryNameSpace["_zsc_"] = _zsc_ +_zcUserQueryNameSpace["sc"] = sc + +if sparkVersion.isSpark2(): + spark = __zSpark__ = SparkSession(sc, intp.getSparkSession()) + sqlc = __zSqlc__ = __zSpark__._wrapped + _zcUserQueryNameSpace["sqlc"] = sqlc + _zcUserQueryNameSpace["__zSqlc__"] = __zSqlc__ + _zcUserQueryNameSpace["spark"] = spark + _zcUserQueryNameSpace["__zSpark__"] = __zSpark__ +else: + sqlc = __zSqlc__ = SQLContext(sparkContext=sc, sqlContext=intp.getSQLContext()) + _zcUserQueryNameSpace["sqlc"] = sqlc + _zcUserQueryNameSpace["__zSqlc__"] = sqlc + +sqlContext = __zSqlc__ +_zcUserQueryNameSpace["sqlContext"] = sqlContext + +completion = __zeppelin_completion__ = PySparkCompletion(intp) +_zcUserQueryNameSpace["completion"] = completion +_zcUserQueryNameSpace["__zeppelin_completion__"] = __zeppelin_completion__ + +z = __zeppelin__ = PyZeppelinContext(intp.getZeppelinContext()) +__zeppelin__._setup_matplotlib() +_zcUserQueryNameSpace["z"] = z +_zcUserQueryNameSpace["__zeppelin__"] = __zeppelin__ + +while True : + req = intp.getStatements() + try: + stmts = req.statements().split("\n") + jobGroup = req.jobGroup() + jobDesc = req.jobDescription() + + # Get post-execute hooks + try: + global_hook = intp.getHook('post_exec_dev') + except: + global_hook = None + + try: + user_hook = __zeppelin__.getHook('post_exec') + except: + user_hook = None + + nhooks = 0 + for hook in (global_hook, user_hook): + if hook: + nhooks += 1 + + if stmts: + # use exec mode to compile the statements except the last statement, + # so that the last statement's evaluation will be printed to stdout + sc.setJobGroup(jobGroup, jobDesc) + code = compile('\n'.join(stmts), '<stdin>', 'exec', ast.PyCF_ONLY_AST, 1) + to_run_hooks = [] + if (nhooks > 0): + to_run_hooks = code.body[-nhooks:] + to_run_exec, to_run_single = (code.body[:-(nhooks + 1)], + [code.body[-(nhooks + 1)]]) + + try: + for node in to_run_exec: + mod = ast.Module([node]) + code = compile(mod, '<stdin>', 'exec') + exec(code, _zcUserQueryNameSpace) + + for node in to_run_single: + mod = ast.Interactive([node]) + code = compile(mod, '<stdin>', 'single') + exec(code, _zcUserQueryNameSpace) + + for node in to_run_hooks: + mod = ast.Module([node]) + code = compile(mod, '<stdin>', 'exec') + exec(code, _zcUserQueryNameSpace) + + intp.setStatementsFinished("", False) + except 
Py4JJavaError: + # raise it to outside try except + raise + except: + exception = traceback.format_exc() + m = re.search("File \"<stdin>\", line (\d+).*", exception) + if m: + line_no = int(m.group(1)) + intp.setStatementsFinished( + "Fail to execute line {}: {}\n".format(line_no, stmts[line_no - 1]) + exception, True) + else: + intp.setStatementsFinished(exception, True) + else: + intp.setStatementsFinished("", False) + + except Py4JJavaError: + excInnerError = traceback.format_exc() # format_tb() does not return the inner exception + innerErrorStart = excInnerError.find("Py4JJavaError:") + if innerErrorStart > -1: + excInnerError = excInnerError[innerErrorStart:] + intp.setStatementsFinished(excInnerError + str(sys.exc_info()), True) + except: + intp.setStatementsFinished(traceback.format_exc(), True) + + output.reset() http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/main/scala/org/apache/spark/SparkRBackend.scala ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/scala/org/apache/spark/SparkRBackend.scala b/spark/interpreter/src/main/scala/org/apache/spark/SparkRBackend.scala new file mode 100644 index 0000000..05f1ac0 --- /dev/null +++ b/spark/interpreter/src/main/scala/org/apache/spark/SparkRBackend.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark + +import org.apache.spark.api.r.RBackend + +object SparkRBackend { + val backend : RBackend = new RBackend() + private var started = false; + private var portNumber = 0; + + val backendThread : Thread = new Thread("SparkRBackend") { + override def run() { + backend.run() + } + } + + def init() : Int = { + portNumber = backend.init() + portNumber + } + + def start() : Unit = { + backendThread.start() + started = true + } + + def close() : Unit = { + backend.close() + backendThread.join() + } + + def isStarted() : Boolean = { + started + } + + def port(): Int = { + return portNumber + } +} http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/main/scala/org/apache/zeppelin/spark/ZeppelinRDisplay.scala ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/scala/org/apache/zeppelin/spark/ZeppelinRDisplay.scala b/spark/interpreter/src/main/scala/org/apache/zeppelin/spark/ZeppelinRDisplay.scala new file mode 100644 index 0000000..a9014c2 --- /dev/null +++ b/spark/interpreter/src/main/scala/org/apache/zeppelin/spark/ZeppelinRDisplay.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark + +import org.apache.zeppelin.interpreter.InterpreterResult.Code +import org.apache.zeppelin.interpreter.InterpreterResult.Code.{SUCCESS} +import org.apache.zeppelin.interpreter.InterpreterResult.Type +import org.apache.zeppelin.interpreter.InterpreterResult.Type.{TEXT, HTML, TABLE, IMG} +import org.jsoup.Jsoup +import org.jsoup.nodes.Element +import org.jsoup.nodes.Document.OutputSettings +import org.jsoup.safety.Whitelist + +import scala.collection.JavaConversions._ +import scala.util.matching.Regex + +case class RDisplay(content: String, `type`: Type, code: Code) + +object ZeppelinRDisplay { + + val pattern = new Regex("""^ *\[\d*\] """) + + def render(html: String, imageWidth: String): RDisplay = { + + val document = Jsoup.parse(html) + document.outputSettings().prettyPrint(false) + + val body = document.body() + + if (body.getElementsByTag("p").isEmpty) return RDisplay(body.html(), HTML, SUCCESS) + + val bodyHtml = body.html() + + if (! bodyHtml.contains("<img") + && ! bodyHtml.contains("<script") + && ! bodyHtml.contains("%html ") + && ! bodyHtml.contains("%table ") + && ! bodyHtml.contains("%img ") + ) { + return textDisplay(body) + } + + if (bodyHtml.contains("%table")) { + return tableDisplay(body) + } + + if (bodyHtml.contains("%img")) { + return imgDisplay(body) + } + + return htmlDisplay(body, imageWidth) + } + + private def textDisplay(body: Element): RDisplay = { + // remove HTML tag while preserving whitespaces and newlines + val text = Jsoup.clean(body.html(), "", + Whitelist.none(), new OutputSettings().prettyPrint(false)) + RDisplay(text, TEXT, SUCCESS) + } + + private def tableDisplay(body: Element): RDisplay = { + val p = body.getElementsByTag("p").first().html.replace("“%table " , "").replace("”", "") + val r = (pattern findFirstIn p).getOrElse("") + val table = p.replace(r, "").replace("\\t", "\t").replace("\\n", "\n") + RDisplay(table, TABLE, SUCCESS) + } + + private def imgDisplay(body: Element): RDisplay = { + val p = body.getElementsByTag("p").first().html.replace("“%img " , "").replace("”", "") + val r = (pattern findFirstIn p).getOrElse("") + val img = p.replace(r, "") + RDisplay(img, IMG, SUCCESS) + } + + private def htmlDisplay(body: Element, imageWidth: String): RDisplay = { + var div = new String() + + for (element <- body.children) { + + val eHtml = element.html() + var eOuterHtml = element.outerHtml() + + eOuterHtml = eOuterHtml.replace("“%html " , "").replace("”", "") + + val r = (pattern findFirstIn eHtml).getOrElse("") + + div = div + eOuterHtml.replace(r, "") + } + + val content = div + .replaceAll("src=\"//", "src=\"http://") + .replaceAll("href=\"//", "href=\"http://") + + body.html(content) + + for (image <- body.getElementsByTag("img")) { + image.attr("width", imageWidth) + } + + RDisplay(body.html, HTML, SUCCESS) + } +}
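ZeppelinRDisplay.render classifies the HTML that knitr produces: a body without <p> elements is passed through as HTML, and paragraphs carrying the curly-quoted %table / %img / %html markers are unwrapped into Zeppelin's TABLE, IMG and HTML result types, with plain text as the fallback. A short usage sketch under stated assumptions: the standalone object and the sample HTML strings are made up for illustration, and the InterpreterResult(Code, Type, String) constructor at the end is assumed to be how a calling interpreter would package the result.

  import org.apache.zeppelin.interpreter.InterpreterResult
  import org.apache.zeppelin.spark.{RDisplay, ZeppelinRDisplay}

  object RDisplaySketch {
    def main(args: Array[String]): Unit = {
      // HTML with no <p> element is returned unchanged with type HTML.
      val html: RDisplay = ZeppelinRDisplay.render("<div><b>hello</b></div>", "100%")

      // A knitr paragraph tagged with a curly-quoted "%table" marker is unwrapped into the
      // tab-separated TABLE format; the leading "[1] " echo and escaped \t / \n are cleaned up.
      val table: RDisplay = ZeppelinRDisplay.render(
        "<p>[1] \u201C%table a\\tb\\n1\\t2\u201D</p>", "100%")

      // A calling interpreter would then wrap the result, roughly:
      val result = new InterpreterResult(table.code, table.`type`, table.content)
      println(s"${html.`type`} / ${table.`type`} / ${result.code()}")
    }
  }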
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/main/scala/org/apache/zeppelin/spark/utils/DisplayUtils.scala ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/scala/org/apache/zeppelin/spark/utils/DisplayUtils.scala b/spark/interpreter/src/main/scala/org/apache/zeppelin/spark/utils/DisplayUtils.scala new file mode 100644 index 0000000..8181434 --- /dev/null +++ b/spark/interpreter/src/main/scala/org/apache/zeppelin/spark/utils/DisplayUtils.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark.utils + +import java.lang.StringBuilder + +import org.apache.spark.rdd.RDD + +import scala.collection.IterableLike + +object DisplayUtils { + + implicit def toDisplayRDDFunctions[T <: Product](rdd: RDD[T]): DisplayRDDFunctions[T] = new DisplayRDDFunctions[T](rdd) + + implicit def toDisplayTraversableFunctions[T <: Product](traversable: Traversable[T]): DisplayTraversableFunctions[T] = new DisplayTraversableFunctions[T](traversable) + + def html(htmlContent: String = "") = s"%html $htmlContent" + + def img64(base64Content: String = "") = s"%img $base64Content" + + def img(url: String) = s"<img src='$url' />" +} + +trait DisplayCollection[T <: Product] { + + def printFormattedData(traversable: Traversable[T], columnLabels: String*): Unit = { + val providedLabelCount: Int = columnLabels.size + var maxColumnCount:Int = 1 + val headers = new StringBuilder("%table ") + + val data = new StringBuilder("") + + traversable.foreach(tuple => { + maxColumnCount = math.max(maxColumnCount,tuple.productArity) + data.append(tuple.productIterator.mkString("\t")).append("\n") + }) + + if (providedLabelCount > maxColumnCount) { + headers.append(columnLabels.take(maxColumnCount).mkString("\t")).append("\n") + } else if (providedLabelCount < maxColumnCount) { + val missingColumnHeaders = ((providedLabelCount+1) to maxColumnCount).foldLeft[String](""){ + (stringAccumulator,index) => if (index==1) s"Column$index" else s"$stringAccumulator\tColumn$index" + } + + headers.append(columnLabels.mkString("\t")).append(missingColumnHeaders).append("\n") + } else { + headers.append(columnLabels.mkString("\t")).append("\n") + } + + headers.append(data) + + print(headers.toString) + } + +} + +class DisplayRDDFunctions[T <: Product] (val rdd: RDD[T]) extends DisplayCollection[T] { + + def display(columnLabels: String*)(implicit sparkMaxResult: SparkMaxResult): Unit = { + printFormattedData(rdd.take(sparkMaxResult.maxResult), columnLabels: _*) + } + + def display(sparkMaxResult:Int, columnLabels: String*): Unit = { + printFormattedData(rdd.take(sparkMaxResult), columnLabels: _*) + } +} + +class DisplayTraversableFunctions[T <: Product] 
(val traversable: Traversable[T]) extends DisplayCollection[T] { + + def display(columnLabels: String*): Unit = { + printFormattedData(traversable, columnLabels: _*) + } +} + +class SparkMaxResult(val maxResult: Int) extends Serializable http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/test/java/org/apache/zeppelin/spark/DepInterpreterTest.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/DepInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/DepInterpreterTest.java new file mode 100644 index 0000000..e177d49 --- /dev/null +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/DepInterpreterTest.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark; + +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Properties; + +import org.apache.zeppelin.display.AngularObjectRegistry; +import org.apache.zeppelin.user.AuthenticationInfo; +import org.apache.zeppelin.display.GUI; +import org.apache.zeppelin.interpreter.*; +import org.apache.zeppelin.interpreter.InterpreterResult.Code; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class DepInterpreterTest { + + @Rule + public TemporaryFolder tmpDir = new TemporaryFolder(); + + private DepInterpreter dep; + private InterpreterContext context; + + private Properties getTestProperties() throws IOException { + Properties p = new Properties(); + p.setProperty("zeppelin.dep.localrepo", tmpDir.newFolder().getAbsolutePath()); + p.setProperty("zeppelin.dep.additionalRemoteRepository", "spark-packages,http://dl.bintray.com/spark-packages/maven,false;"); + return p; + } + + @Before + public void setUp() throws Exception { + Properties p = getTestProperties(); + + dep = new DepInterpreter(p); + dep.open(); + + InterpreterGroup intpGroup = new InterpreterGroup(); + intpGroup.put("note", new LinkedList<Interpreter>()); + intpGroup.get("note").add(new SparkInterpreter(p)); + intpGroup.get("note").add(dep); + dep.setInterpreterGroup(intpGroup); + + context = new InterpreterContext("note", "id", null, "title", "text", new AuthenticationInfo(), + new HashMap<String, Object>(), new GUI(), new GUI(), + new AngularObjectRegistry(intpGroup.getId(), null), + null, + new LinkedList<InterpreterContextRunner>(), null); + } + + @After + public void tearDown() throws Exception { + dep.close(); + } + + @Test + public void testDefault() { + dep.getDependencyContext().reset(); + InterpreterResult ret = 
dep.interpret("z.load(\"org.apache.commons:commons-csv:1.1\")", context); + assertEquals(Code.SUCCESS, ret.code()); + + assertEquals(1, dep.getDependencyContext().getFiles().size()); + assertEquals(1, dep.getDependencyContext().getFilesDist().size()); + + // Add a test for the spark-packages repo - default in additionalRemoteRepository + ret = dep.interpret("z.load(\"amplab:spark-indexedrdd:0.3\")", context); + assertEquals(Code.SUCCESS, ret.code()); + + // Reset at the end of the test + dep.getDependencyContext().reset(); + } +} http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java new file mode 100644 index 0000000..765237c --- /dev/null +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.zeppelin.spark; + + +import com.google.common.io.Files; +import org.apache.zeppelin.display.GUI; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.InterpreterOutput; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.InterpreterResultMessage; +import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; +import org.apache.zeppelin.python.IPythonInterpreterTest; +import org.apache.zeppelin.user.AuthenticationInfo; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.net.URL; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Properties; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class IPySparkInterpreterTest { + + private IPySparkInterpreter iPySparkInterpreter; + private InterpreterGroup intpGroup; + + @Before + public void setup() throws InterpreterException { + Properties p = new Properties(); + p.setProperty("spark.master", "local[4]"); + p.setProperty("master", "local[4]"); + p.setProperty("spark.submit.deployMode", "client"); + p.setProperty("spark.app.name", "Zeppelin Test"); + p.setProperty("zeppelin.spark.useHiveContext", "true"); + p.setProperty("zeppelin.spark.maxResult", "1000"); + p.setProperty("zeppelin.spark.importImplicit", "true"); + p.setProperty("zeppelin.pyspark.python", "python"); + p.setProperty("zeppelin.dep.localrepo", Files.createTempDir().getAbsolutePath()); + + intpGroup = new InterpreterGroup(); + intpGroup.put("session_1", new LinkedList<Interpreter>()); + + SparkInterpreter sparkInterpreter = new SparkInterpreter(p); + intpGroup.get("session_1").add(sparkInterpreter); + sparkInterpreter.setInterpreterGroup(intpGroup); + sparkInterpreter.open(); + + iPySparkInterpreter = new IPySparkInterpreter(p); + intpGroup.get("session_1").add(iPySparkInterpreter); + iPySparkInterpreter.setInterpreterGroup(intpGroup); + iPySparkInterpreter.open(); + } + + + @After + public void tearDown() throws InterpreterException { + if (iPySparkInterpreter != null) { + iPySparkInterpreter.close(); + } + } + + @Test + public void testBasics() throws InterruptedException, IOException, InterpreterException { + // all the ipython test should pass too. 
+ IPythonInterpreterTest.testInterpreter(iPySparkInterpreter); + + // rdd + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = iPySparkInterpreter.interpret("sc.range(1,10).sum()", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + List<InterpreterResultMessage> interpreterResultMessages = context.out.getInterpreterResultMessages(); + assertEquals("45", interpreterResultMessages.get(0).getData()); + + context = getInterpreterContext(); + result = iPySparkInterpreter.interpret("sc.version", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.getInterpreterResultMessages(); + // spark sql + context = getInterpreterContext(); + if (interpreterResultMessages.get(0).getData().startsWith("'1.") || + interpreterResultMessages.get(0).getData().startsWith("u'1.")) { + result = iPySparkInterpreter.interpret("df = sqlContext.createDataFrame([(1,'a'),(2,'b')])\ndf.show()", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.getInterpreterResultMessages(); + assertEquals( + "+---+---+\n" + + "| _1| _2|\n" + + "+---+---+\n" + + "| 1| a|\n" + + "| 2| b|\n" + + "+---+---+\n\n", interpreterResultMessages.get(0).getData()); + } else { + result = iPySparkInterpreter.interpret("df = spark.createDataFrame([(1,'a'),(2,'b')])\ndf.show()", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.getInterpreterResultMessages(); + assertEquals( + "+---+---+\n" + + "| _1| _2|\n" + + "+---+---+\n" + + "| 1| a|\n" + + "| 2| b|\n" + + "+---+---+\n\n", interpreterResultMessages.get(0).getData()); + } + + // cancel + final InterpreterContext context2 = getInterpreterContext(); + + Thread thread = new Thread() { + @Override + public void run() { + InterpreterResult result = iPySparkInterpreter.interpret("import time\nsc.range(1,10).foreach(lambda x: time.sleep(1))", context2); + assertEquals(InterpreterResult.Code.ERROR, result.code()); + List<InterpreterResultMessage> interpreterResultMessages = null; + try { + interpreterResultMessages = context2.out.getInterpreterResultMessages(); + assertTrue(interpreterResultMessages.get(0).getData().contains("KeyboardInterrupt")); + } catch (IOException e) { + e.printStackTrace(); + } + } + }; + thread.start(); + + // sleep 1 second to wait for the spark job starts + Thread.sleep(1000); + iPySparkInterpreter.cancel(context); + thread.join(); + + // completions + List<InterpreterCompletion> completions = iPySparkInterpreter.completion("sc.ran", 6, getInterpreterContext()); + assertEquals(1, completions.size()); + assertEquals("range", completions.get(0).getValue()); + + // pyspark streaming + + Class klass = py4j.GatewayServer.class; + URL location = klass.getResource('/' + klass.getName().replace('.', '/') + ".class"); + System.out.println("py4j location: " + location); + context = getInterpreterContext(); + result = iPySparkInterpreter.interpret( + "from pyspark.streaming import StreamingContext\n" + + "import time\n" + + "ssc = StreamingContext(sc, 1)\n" + + "rddQueue = []\n" + + "for i in range(5):\n" + + " rddQueue += [ssc.sparkContext.parallelize([j for j in range(1, 1001)], 10)]\n" + + "inputStream = ssc.queueStream(rddQueue)\n" + + "mappedStream = inputStream.map(lambda x: (x % 10, 1))\n" + + "reducedStream = mappedStream.reduceByKey(lambda a, b: a + b)\n" + + "reducedStream.pprint()\n" + + 
"ssc.start()\n" + + "time.sleep(6)\n" + + "ssc.stop(stopSparkContext=False, stopGraceFully=True)", context); + Thread.sleep(1000); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.getInterpreterResultMessages(); + assertEquals(1, interpreterResultMessages.size()); +// assertTrue(interpreterResultMessages.get(0).getData().contains("(0, 100)")); + } + + private InterpreterContext getInterpreterContext() { + return new InterpreterContext( + "noteId", + "paragraphId", + "replName", + "paragraphTitle", + "paragraphText", + new AuthenticationInfo(), + new HashMap<String, Object>(), + new GUI(), + new GUI(), + null, + null, + null, + new InterpreterOutput(null)); + } +}