Repository: spark
Updated Branches:
  refs/heads/master a4f4fbc8f -> 1f4c7f7ec


Graph primitives2

Hi guys,

I'm following Joey and Ankur's suggestions to add collectEdges and 
pickRandomVertex. I'm also adding the tests for collectEdges and refactoring 
one method getCycleGraph in GraphOpsSuite.scala.

Thank you,

semih

Author: Semih Salihoglu <[email protected]>

Closes #580 from semihsalihoglu/GraphPrimitives2 and squashes the following 
commits:

937d3ec [Semih Salihoglu] - Fixed the scalastyle errors.
a69a152 [Semih Salihoglu] - Adding collectEdges and pickRandomVertices. - 
Adding tests for collectEdges. - Refactoring a getCycle utility function for 
GraphOpsSuite.scala.
41265a6 [Semih Salihoglu] - Adding collectEdges and pickRandomVertex. - Adding 
tests for collectEdges. - Recycling a getCycle utility test file.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1f4c7f7e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1f4c7f7e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1f4c7f7e

Branch: refs/heads/master
Commit: 1f4c7f7ecc9d2393663fc4d059e71fe4c70bad84
Parents: a4f4fbc
Author: Semih Salihoglu <[email protected]>
Authored: Mon Feb 24 22:42:30 2014 -0800
Committer: Reynold Xin <[email protected]>
Committed: Mon Feb 24 22:42:30 2014 -0800

----------------------------------------------------------------------
 .../org/apache/spark/graphx/GraphOps.scala      |  59 +++++++-
 .../org/apache/spark/graphx/GraphOpsSuite.scala | 134 +++++++++++++++++--
 2 files changed, 183 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1f4c7f7e/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
----------------------------------------------------------------------
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala 
b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index 0fc1e4d..377d9d6 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -18,11 +18,11 @@
 package org.apache.spark.graphx
 
 import scala.reflect.ClassTag
-
 import org.apache.spark.SparkContext._
 import org.apache.spark.SparkException
 import org.apache.spark.graphx.lib._
 import org.apache.spark.rdd.RDD
+import scala.util.Random
 
 /**
  * Contains additional functionality for [[Graph]]. All operations are 
expressed in terms of the
@@ -138,6 +138,42 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: 
Graph[VD, ED]) extends Seriali
   } // end of collectNeighbor
 
   /**
+   * Returns an RDD that contains for each vertex v its local edges,
+   * i.e., the edges that are incident on v, in the user-specified direction.
+   * Warning: note that singleton vertices, those with no edges in the given
+   * direction will not be part of the return value.
+   *
+   * With `EdgeDirection.Either` every edge is emitted once for its source and
+   * once for its destination, so a self-loop shows up twice in the array of
+   * its vertex.
+   *
+   * @note This function could be highly inefficient on power-law
+   * graphs where high degree vertices may force a large amount of
+   * information to be collected to a single location.
+   *
+   * @param edgeDirection the direction along which to collect
+   * the local edges of vertices
+   *
+   * @return the local edges for each vertex
+   *
+   * @throws SparkException if `edgeDirection` is `EdgeDirection.Both`
+   */
+  def collectEdges(edgeDirection: EdgeDirection): VertexRDD[Array[Edge[ED]]] = {
+    edgeDirection match {
+      case EdgeDirection.Either =>
+        graph.mapReduceTriplets[Array[Edge[ED]]](
+          edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))),
+                           (edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
+          (a, b) => a ++ b)
+      case EdgeDirection.In =>
+        graph.mapReduceTriplets[Array[Edge[ED]]](
+          edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
+          (a, b) => a ++ b)
+      case EdgeDirection.Out =>
+        graph.mapReduceTriplets[Array[Edge[ED]]](
+          edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
+          (a, b) => a ++ b)
+      case EdgeDirection.Both =>
+        // Keep the trailing space after "Use": the two literals are concatenated verbatim.
+        throw new SparkException("collectEdges does not support EdgeDirection.Both. Use " +
+          "EdgeDirection.Either instead.")
+    }
+  }
+ 
+  /**
    * Join the vertices with an RDD and then apply a function from the
    * the vertex and RDD entry to a new vertex value.  The input table
    * should contain at most one entry for each vertex.  If no entry is
@@ -210,6 +246,27 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: 
Graph[VD, ED]) extends Seriali
   }
 
   /**
+   * Picks a random vertex from the graph and returns its ID.
+   *
+   * Every vertex is independently kept as a candidate with probability
+   * `50 / numVertices` (so about 50 candidates are expected on large graphs,
+   * and every vertex is a candidate when the graph has at most 50 vertices).
+   * One of the collected candidates is then chosen on the driver. If a
+   * sampling round yields no candidate at all, the round is repeated.
+   *
+   * @return the ID of the chosen vertex
+   *
+   * @throws SparkException if the graph has no vertices
+   */
+  def pickRandomVertex(): VertexId = {
+    val numVertices = graph.numVertices
+    if (numVertices == 0) {
+      // Without this guard the sampling loop below could never terminate.
+      throw new SparkException("pickRandomVertex cannot be called on a graph with no vertices.")
+    }
+    // Floating-point division: an integral `50 / numVertices` truncates to 0 as soon as the
+    // graph has more than 50 vertices, which would make the loop spin forever.
+    val probability = 50.0 / numVertices
+    var candidates = Array.empty[VertexId]
+    while (candidates.isEmpty) {
+      // Collect exactly once per round. The sampled RDD is not cached, so running a separate
+      // count and collect would re-draw the random numbers and could disagree with each other.
+      candidates = graph.vertices.flatMap { case (vid, _) =>
+        if (Random.nextDouble() < probability) Some(vid) else None
+      }.collect()
+    }
+    candidates(Random.nextInt(candidates.length))
+  }
+
+  /**
    * Execute a Pregel-like iterative vertex-parallel abstraction.  The
    * user-defined vertex-program `vprog` is executed in parallel on
    * each vertex receiving any inbound messages and computing a new

http://git-wip-us.apache.org/repos/asf/spark/blob/1f4c7f7e/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
----------------------------------------------------------------------
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala 
b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
index bc2ad56..6386306 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
@@ -42,21 +42,20 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext 
{
 
   test("collectNeighborIds") {
     withSpark { sc =>
-      val chain = (0 until 100).map(x => (x, (x+1)%100) )
-      val rawEdges = sc.parallelize(chain, 3).map { case (s,d) => (s.toLong, 
d.toLong) }
-      val graph = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
+      val graph = getCycleGraph(sc, 100)
       val nbrs = graph.collectNeighborIds(EdgeDirection.Either).cache()
-      assert(nbrs.count === chain.size)
+      assert(nbrs.count === 100)
       assert(graph.numVertices === nbrs.count)
       nbrs.collect.foreach { case (vid, nbrs) => assert(nbrs.size === 2) }
-      nbrs.collect.foreach { case (vid, nbrs) =>
-        val s = nbrs.toSet
-        assert(s.contains((vid + 1) % 100))
-        assert(s.contains(if (vid > 0) vid - 1 else 99 ))
+      nbrs.collect.foreach {
+        case (vid, nbrs) =>
+          val s = nbrs.toSet
+          assert(s.contains((vid + 1) % 100))
+          assert(s.contains(if (vid > 0) vid - 1 else 99))
       }
     }
   }
-
+  
   test ("filter") {
     withSpark { sc =>
       val n = 5
@@ -80,4 +79,121 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext 
{
     }
   }
 
+  test("collectEdgesCycleDirectionOut") {
+    withSpark { sc =>
+      // In the directed 100-cycle each vertex has exactly one out-edge, pointing at its successor.
+      val cycleGraph = getCycleGraph(sc, 100)
+      val outEdges = cycleGraph.collectEdges(EdgeDirection.Out).cache()
+      assert(outEdges.count == 100)
+      val perVertex = outEdges.collect
+      perVertex.foreach { case (_, incident) => assert(incident.size == 1) }
+      perVertex.foreach { case (vid, incident) =>
+        val dstIds = incident.map(_.dstId).toSet
+        assert(dstIds.contains((vid + 1) % 100))
+      }
+    }
+  }
+
+  test("collectEdgesCycleDirectionIn") {
+    withSpark { sc =>
+      // In the directed 100-cycle each vertex has exactly one in-edge, coming from its
+      // predecessor (vertex 0 is preceded by vertex 99).
+      val cycleGraph = getCycleGraph(sc, 100)
+      val inEdges = cycleGraph.collectEdges(EdgeDirection.In).cache()
+      assert(inEdges.count == 100)
+      val perVertex = inEdges.collect
+      perVertex.foreach { case (_, incident) => assert(incident.size == 1) }
+      perVertex.foreach { case (vid, incident) =>
+        val srcIds = incident.map(_.srcId).toSet
+        val predecessor = if (vid > 0) vid - 1 else 99
+        assert(srcIds.contains(predecessor))
+      }
+    }
+  }
+
+  test("collectEdgesCycleDirectionEither") {
+    withSpark { sc =>
+      // Ignoring direction, each vertex of the 100-cycle touches two edges: one to its
+      // successor and one from its predecessor.
+      val cycleGraph = getCycleGraph(sc, 100)
+      val anyEdges = cycleGraph.collectEdges(EdgeDirection.Either).cache()
+      assert(anyEdges.count == 100)
+      val perVertex = anyEdges.collect
+      perVertex.foreach { case (_, incident) => assert(incident.size == 2) }
+      perVertex.foreach { case (vid, incident) =>
+        // For every incident edge keep the endpoint that is not `vid` itself.
+        val neighborIds = incident.map(e => if (vid != e.srcId) e.srcId else e.dstId).toSet
+        assert(neighborIds.contains((vid + 1) % 100))
+        assert(neighborIds.contains(if (vid > 0) vid - 1 else 99))
+      }
+    }
+  }
+
+  test("collectEdgesChainDirectionOut") {
+    withSpark { sc =>
+      // A 50-vertex chain has 49 edges; the last vertex has no out-edge and is therefore
+      // absent from the result.
+      val chainGraph = getChainGraph(sc, 50)
+      val outEdges = chainGraph.collectEdges(EdgeDirection.Out).cache()
+      assert(outEdges.count == 49)
+      val perVertex = outEdges.collect
+      perVertex.foreach { case (_, incident) => assert(incident.size == 1) }
+      perVertex.foreach { case (vid, incident) =>
+        val dstIds = incident.map(_.dstId).toSet
+        assert(dstIds.contains(vid + 1))
+      }
+    }
+  }
+
+  test("collectEdgesChainDirectionIn") {
+    withSpark { sc =>
+      val graph = getChainGraph(sc, 50)
+      val edges = graph.collectEdges(EdgeDirection.In).cache()
+      // We expect only 49 because collectEdges does not return vertices that do
+      // not have any edges in the specified direction (here: the first vertex of the chain).
+      assert(edges.count === 49)
+      val perVertex = edges.collect
+      perVertex.foreach { case (vid, inEdges) => assert(inEdges.size === 1) }
+      perVertex.foreach {
+        case (vid, inEdges) =>
+          // The single in-edge of vertex v comes from v - 1. A chain has no wrap-around,
+          // so no modulo is involved.
+          val edgeSrcIds = inEdges.map(e => e.srcId).toSet
+          assert(edgeSrcIds.contains(vid - 1))
+      }
+    }
+  }
+
+  test("collectEdgesChainDirectionEither") {
+    withSpark { sc =>
+      val graph = getChainGraph(sc, 50)
+      val edges = graph.collectEdges(EdgeDirection.Either).cache()
+      // All 50 vertices are expected here: when direction is ignored, every vertex
+      // of the chain has at least one incident edge.
+      assert(edges.count === 50)
+      val perVertex = edges.collect
+      perVertex.foreach {
+        case (vid, incident) =>
+          // Inner vertices touch two edges, the two end points of the chain only one.
+          val expectedSize = if (vid > 0 && vid < 49) 2 else 1
+          assert(incident.size === expectedSize)
+      }
+      perVertex.foreach {
+        case (vid, incident) =>
+          // For every incident edge keep the endpoint that is not `vid` itself.
+          val edgeIds = incident.map(e => if (vid != e.srcId) e.srcId else e.dstId).toSet
+          if (vid == 0) {
+            assert(edgeIds.contains(1L))
+          } else if (vid == 49) {
+            assert(edgeIds.contains(48L))
+          } else {
+            assert(edgeIds.contains(vid + 1))
+            assert(edgeIds.contains(vid - 1))
+          }
+      }
+    }
+  }
+
+  /** Builds a directed cycle over vertices 0 until `numVertices`: each v points to (v + 1) % n. */
+  private def getCycleGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = {
+    val cycleEdges = for (v <- 0 until numVertices) yield (v, (v + 1) % numVertices)
+    getGraphFromSeq(sc, cycleEdges)
+  }
+
+  /** Builds a directed chain 0 -> 1 -> ... -> (numVertices - 1), i.e. numVertices - 1 edges. */
+  private def getChainGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = {
+    val chainEdges = (1 until numVertices).map(v => (v - 1, v))
+    getGraphFromSeq(sc, chainEdges)
+  }
+
+  /**
+   * Turns a sequence of (source, destination) pairs into a cached graph. The pairs are
+   * distributed over 3 partitions, widened to Long ids and handed to `Graph.fromEdgeTuples`
+   * with 1.0 as the vertex value.
+   */
+  private def getGraphFromSeq(
+      sc: SparkContext,
+      seq: IndexedSeq[(Int, Int)]): Graph[Double, Int] = {
+    val longPairs = sc.parallelize(seq, 3).map(pair => (pair._1.toLong, pair._2.toLong))
+    val graph = Graph.fromEdgeTuples(longPairs, 1.0)
+    graph.cache()
+  }
 }

Reply via email to