使用 Groovy 与 Apache Wayang 和 Apache Spark
作者: Paul King
发布时间: 2022-06-19 01:01PM
Apache Wayang (孵化器项目) 是一个用于大数据跨平台处理的 API。它提供了对其他平台的抽象,例如 Apache Spark 和 Apache Flink,以及一个默认的内置流式 "平台"。其目标是在编写代码时提供一致的开发体验,无论最终需要的是轻量级平台还是高可扩展平台。应用程序的执行是在逻辑计划中指定的,该计划也是平台无关的。Wayang 将把逻辑计划转换为一组物理操作符,由特定的底层处理平台执行。
威士忌聚类
我们将研究如何使用 Apache Wayang 和 Groovy 来帮助我们寻找完美的单一麦芽苏格兰威士忌。来自 86 家酿酒厂 生产的威士忌,由专家品酒师根据 12 个标准(酒体、甜味、麦芽味、烟熏味、果味等)进行了排名。我们将使用 KMeans 算法来计算质心。这类似于 Wayang 文档中的 KMeans 示例,但我们不是使用 2 维(x 和 y 坐标),而是使用 12 维,分别对应于我们的标准。重点是,它说明了典型的数据科学和机器学习算法,这些算法涉及迭代(典型的映射、过滤、归约式处理)。
KMeans 是一种标准的数据科学聚类技术。在我们的例子中,它将具有相似特征(根据 12 个标准)的威士忌分组到聚类中。如果我们有一款最喜欢的威士忌,很有可能我们可以通过查看同一聚类中的其他实例来找到类似的威士忌。如果我们想换换口味,我们可以寻找其他聚类中的威士忌。质心是聚类中心的虚拟 "点"。对我们来说,它反映了该聚类中威士忌每个标准的典型度量。
实现细节
我们将从定义一个 Point 记录开始
record Point(double[] pts) implements Serializable {
static Point fromLine(String line) { new Point(line.split(',')[2..-1]*.toDouble() as double[]) }
}
我们已将其设置为 Serializable
(稍后会详细介绍),并包含了一个 fromLine
工厂方法,以帮助我们从 CSV 文件中创建点。我们将自己完成此操作,而不是依赖其他可以提供帮助的库。对我们来说,它不是一个 2D 或 3D 点,而是 12D,对应于 12 个标准。我们只使用 double
数组,因此任何维度都受支持,但 12 来自我们数据文件中列的数量。
我们将定义一个相关的 TaggedPointCounter
记录。它类似于 Point
,但跟踪了一个 int
聚类 ID 和 long
计数,这些计数在聚类点时使用
record TaggedPointCounter(double[] pts, int cluster, long count) implements Serializable {
TaggedPointCounter plus(TaggedPointCounter that) {
new TaggedPointCounter((0..<pts.size()).collect{ pts[it] + that.pts[it] } as double[], cluster, count + that.count)
}
TaggedPointCounter average() {
new TaggedPointCounter(pts.collect{ double d -> d/count } as double[], cluster, 0)
}
}
我们有 plus
和 average
方法,这些方法在算法的映射/归约部分中很有用。
KMeans 算法的另一个方面是将点分配到与其最近质心相关的聚类。对于 2 维,回想一下毕达哥拉斯定理,这将是 x 平方加 y 平方开根号,其中 x 和 y 分别是点到质心在 x 和 y 维上的距离。我们将对所有维度进行相同的操作,并定义以下辅助类来捕获算法的这一部分
class SelectNearestCentroid implements ExtendedSerializableFunction<Point, TaggedPointCounter> {
Iterable<TaggedPointCounter> centroids
void open(ExecutionContext context) {
centroids = context.getBroadcast("centroids")
}
TaggedPointCounter apply(Point p) {
def minDistance = Double.POSITIVE_INFINITY
def nearestCentroidId = -1
for (c in centroids) {
def distance = sqrt((0..<p.pts.size()).collect{ p.pts[it] - c.pts[it] }.sum{ it ** 2 } as double)
if (distance < minDistance) {
minDistance = distance
nearestCentroidId = c.cluster
}
}
new TaggedPointCounter(p.pts, nearestCentroidId, 1)
}
}
在 Wayang 行话中,SelectNearestCentroid
类是一个 UDF,即用户定义函数。它代表一些可以根据运行操作的位置做出优化决策的功能块。
一旦我们开始使用 Spark,我们算法的映射/归约部分中的类将需要可序列化。动态 Groovy 中的方法闭包不可序列化。我们有一些选项可以避免使用它们。我将在这里展示一种方法,即在可能通常使用方法引用的位置使用一些辅助类。以下是一些辅助类
class Cluster implements SerializableFunction<TaggedPointCounter, Integer> {
Integer apply(TaggedPointCounter tpc) { tpc.cluster() }
}
class Average implements SerializableFunction<TaggedPointCounter, TaggedPointCounter> {
TaggedPointCounter apply(TaggedPointCounter tpc) { tpc.average() }
}
class Plus implements SerializableBinaryOperator<TaggedPointCounter> {
TaggedPointCounter apply(TaggedPointCounter tpc1, TaggedPointCounter tpc2) { tpc1.plus(tpc2) }
}
现在我们准备好了 KMeans 脚本
int k = 5
int iterations = 20
// read in data from our file
def url = WhiskeyWayang.classLoader.getResource('whiskey.csv').file
def pointsData = new File(url).readLines()[1..-1].collect{ Point.fromLine(it) }
def dims = pointsData[0].pts().size()
// create some random points as initial centroids
def r = new Random()
def initPts = (1..k).collect { (0..<dims).collect { r.nextGaussian() + 2 } as double[] }
// create planbuilder with Java and Spark enabled
def configuration = new Configuration()
def context = new WayangContext(configuration)
.withPlugin(Java.basicPlugin())
.withPlugin(Spark.basicPlugin())
def planBuilder = new JavaPlanBuilder(context, "KMeans ($url, k=$k, iterations=$iterations)")
def points = planBuilder
.loadCollection(pointsData).withName('Load points')
def initialCentroids = planBuilder
.loadCollection((0..<k).collect{ idx -> new TaggedPointCounter(initPts[idx], idx, 0) })
.withName("Load random centroids")
def finalCentroids = initialCentroids
.repeat(iterations, currentCentroids ->
points.map(new SelectNearestCentroid())
.withBroadcast(currentCentroids, "centroids").withName("Find nearest centroid")
.reduceByKey(new Cluster(), new Plus()).withName("Add up points")
.map(new Average()).withName("Average points")
.withOutputClass(TaggedPointCounter)).withName("Loop").collect()
println 'Centroids:'
finalCentroids.each { c ->
println "Cluster$c.cluster: ${c.pts.collect{ sprintf('%.3f', it) }.join(', ')}"
}
这里,k
是所需聚类的数量,iterations
是遍历 KMeans 循环的次数。pointsData
变量是从数据文件加载的 Point
实例列表。如果我们的数据集很大,我们将使用 readTextFile
方法而不是 loadCollection
。initPts
变量是我们初始质心的一些随机起始位置。由于是随机的,并且考虑到 KMeans 算法的工作方式,我们的某些聚类可能没有分配任何点。我们的算法通过在每次迭代中将所有点分配到其最近的当前质心,然后根据这些分配计算新的质心来工作。最后,我们输出结果。
使用基于 Java 流的平台运行
如前所述,Wayang 选择哪个平台运行我们的应用程序。它拥有众多功能,可以利用成本函数和负载估算器来影响和优化应用程序的运行方式。对于我们的简单示例,了解这一点就足够了,即即使我们指定了 Java 或 Spark 作为选项,Wayang 也知道对于我们的小型数据集,Java 流选项是最佳选择。
由于我们使用随机数据启动算法,因此我们预计每次运行脚本时结果都会略有不同,但以下是一个输出
> Task :WhiskeyWayang:run
Centroids:
Cluster0: 2.548, 2.419, 1.613, 0.194, 0.097, 1.871, 1.742, 1.774, 1.677, 1.935, 1.806, 1.613
Cluster2: 1.464, 2.679, 1.179, 0.321, 0.071, 0.786, 1.429, 0.429, 0.964, 1.643, 1.929, 2.179
Cluster3: 3.250, 1.500, 3.250, 3.000, 0.500, 0.250, 1.625, 0.375, 1.375, 1.375, 1.250, 0.250
Cluster4: 1.684, 1.842, 1.211, 0.421, 0.053, 1.316, 0.632, 0.737, 1.895, 2.000, 1.842, 1.737
...
如果绘制出来,它看起来像这样
如果您感兴趣,请查看本文末尾的仓库链接中的示例,以查看生成此质心蜘蛛图的代码,或查看此项目 GitHub 仓库中的 Jupyter/BeakerX 笔记本。
使用 Apache Spark 运行
鉴于我们的数据集大小很小,并且没有其他自定义配置,Wayang 将选择基于 Java 流的解决方案。我们可以使用 Wayang 优化功能来影响它选择的处理平台,但为了保持简单,我们将在配置中禁用 Java 流平台,方法是在代码中进行以下更改
...
def configuration = new Configuration()
def context = new WayangContext(configuration)
// .withPlugin(Java.basicPlugin()) (1)
.withPlugin(Spark.basicPlugin())
def planBuilder = new JavaPlanBuilder(context, "KMeans ($url, k=$k, iterations=$iterations)")
...
-
已禁用
现在,当我们运行应用程序时,输出将类似于以下内容(与之前类似的解决方案,但包含 1000 多行额外的 Spark 和 Wayang 日志信息 - 为了演示目的已截断)
[main] INFO org.apache.spark.SparkContext - Running Spark version 3.3.0 [main] INFO org.apache.spark.util.Utils - Successfully started service 'sparkDriver' on port 62081. ... Centroids: Cluster4: 1.414, 2.448, 0.966, 0.138, 0.034, 0.862, 1.000, 0.483, 1.345, 1.690, 2.103, 2.138 Cluster0: 2.773, 2.455, 1.455, 0.000, 0.000, 1.909, 1.682, 1.955, 2.091, 2.045, 2.136, 1.818 Cluster1: 1.762, 2.286, 1.571, 0.619, 0.143, 1.714, 1.333, 0.905, 1.190, 1.952, 1.095, 1.524 Cluster2: 3.250, 1.500, 3.250, 3.000, 0.500, 0.250, 1.625, 0.375, 1.375, 1.375, 1.250, 0.250 Cluster3: 2.167, 2.000, 2.167, 1.000, 0.333, 0.333, 2.000, 0.833, 0.833, 1.500, 2.333, 1.667 ... [shutdown-hook-0] INFO org.apache.spark.SparkContext - Successfully stopped SparkContext [shutdown-hook-0] INFO org.apache.spark.util.ShutdownHookManager - Shutdown hook called
讨论
Apache Wayang 的目标是允许开发人员编写平台无关的应用程序。虽然这在很大程度上是正确的,但抽象并不完美。例如,如果我知道我仅使用基于流的平台,则无需担心使任何类可序列化(这是 Spark 的要求)。在我们的示例中,我们可以省略 TaggedPointCounter
记录的 implements Serializable
部分,并且可以使用方法引用 TaggedPointCounter::average
来代替我们的 Average
辅助类。这不是对 Wayang 的批评,毕竟,如果您想编写跨平台的 UDF,您可能需要遵循一些规则。相反,它的意思是表明抽象通常会在边缘出现漏洞。有时这些漏洞可以被利用,有时它们是等待着不知情的开发人员的陷阱。
总而言之,如果使用基于 Java 流的平台,您可以在 JDK17(它使用原生记录)以及 JDK11 和 JDK8(Groovy 在其中提供模拟记录)上运行应用程序。此外,如果需要,我们可以进行大量简化。使用 Spark 处理平台时,潜在的简化措施不可用,我们可以在 JDK8 和 JDK11 上运行(Spark 尚未在 JDK17 上得到支持)。
结论
我们研究了如何使用 Apache Wayang 来实现一个 KMeans 算法,该算法要么以 JDK 流功能为后盾运行,要么以 Apache Spark 为后盾运行。Wayang API 为我们隐藏了一些在分布式平台上编写代码的复杂性,以及处理 Spark 平台的一些复杂性。抽象并不完美,但它们确实易于使用,并且在需要在平台之间迁移时提供了额外的保护。此外,它们还开辟了众多优化可能性。
Apache Wayang 是 Apache 的一个孵化器项目,在毕业之前还有很多工作要做,但之前已经进行了大量工作(之前称为 Rheem,并于 2015 年开始)。平台无关的应用程序是一个多年来一直渴望实现的目标,但难以实现。看到 Apache Wayang 在实现这一目标方面取得的进展将是一件令人兴奋的事情。
更多信息
-
包含源代码的仓库: WhiskeyWayang
-
包含使用 Apache Commons CSV、Weka、Smile、Tribuo 等各种库的类似示例的仓库: Whiskey
-
一个直接使用 Apache Spark 但使用
spark-mllib
库中内置的并行 KMeans 而不是手工制作的算法的类似示例: WhiskeySpark -
一个直接使用 Apache Ignite 但使用
ignite-ml
库中内置的集群 KMeans 而不是手工制作的算法的类似示例: WhiskeyIgnite