使用 Oracle 23ai 向量数据类型与 Groovy 对鸢尾花进行分类
作者:Paul King
发布日期:2024-06-30 11:21PM
一个经典的数据科学 数据集 收集了鸢尾花的特征。它记录了三种物种(Setosa、Versicolor 和 Virginica)的萼片和花瓣的宽度和长度。
Iris 项目 位于 groovy-data-science 仓库 中,专门用于此示例。它包含多个 Groovy 脚本和一个 Jupyter/BeakerX 笔记本,重点介绍了这个示例,比较和对比了各种库和各种分类算法。
之前的一篇 博客文章 使用多个深度学习库描述了此示例,并给出了利用 GraalVM 的解决方案。在本篇博客文章中,我们将研究如何使用 Oracle 23ai 的向量数据类型和向量 AI 查询对我们数据集的一部分进行分类。
一般来说,许多机器学习/AI 算法处理信息向量。这些信息可能是实际数据值,比如我们花卉的特征,也可能是数据值的投影,或者文本、视频、图像或音频文件的关键信息的表示。后者通常被称为嵌入。
对我们而言,我们将找到具有相似特征的花卉。在其他相似性搜索场景中,我们可以检测欺诈性交易、找到客户推荐,或者根据图像嵌入的“接近度”找到相似的图像。
数据集
前面提到的 Iris 项目 展示了如何使用各种技术对鸢尾花数据集进行分类。特别是,一个例子使用了 Smile 库的 kNN 分类算法。该示例使用整个数据集来训练模型,然后在整个数据集上运行模型以评估其准确性。如分类与花瓣大小的图表所示,该算法在 Virginica 和 Versicolor 分组重叠处的数据点上存在一些问题。
如果我们查看分类与萼片大小的关系,我们可以看到更多的混淆可能性。
紫色和绿色点表示分类错误的花卉。
相应的混淆矩阵也显示了这些结果。
Confusion matrix: ROW=truth and COL=predicted class 0 | 50 | 0 | 0 | class 1 | 0 | 47 | 3 | class 2 | 0 | 3 | 47 |
一般来说,在原始数据集上运行模型可能并不理想,因为我们无法获得准确的错误计算结果,但这确实突出了我们数据的一些重要信息。在我们的例子中,我们可以看到 Virginica 和 Versacolor 类变得拥挤,这两个组重叠处的数据点可能会出现错误分类的情况。
数据库解决方案
我们的数据存储在 CSV 文件中。
它恰好包含 50 个鸢尾花的三个类。首先,我们从 CSV 文件中加载数据集,跳过标题行并对剩余的行进行洗牌,以确保我们测试的是三个鸢尾花类的随机混合。
var file = getClass().classLoader.getResource('iris_data.csv').file as File
var rows = file.readLines()[1..-1].shuffled() // skip header and shuffle
var (training, test) = rows.chop(rows.size() * .8 as int, -1)
洗牌后,我们将数据分成两个集合。前 80% 将进入数据库。它对应于数据科学术语中的“训练”数据。最后 20% 将对应于我们的“测试”数据。
接下来,我们定义 SQL 连接所需的 정보。
var url = 'jdbc:oracle:thin:@localhost:1521/FREEPDB1'
var user = 'some_user'
var password = 'some_password'
var driver = 'oracle.jdbc.driver.OracleDriver'
接下来,我们创建数据库连接,并使用它来插入“训练”行,然后针对“测试”行进行测试。
Sql.withInstance(url, user, password, driver) { sql ->
training.each { row ->
var data = row.split(',')
var features = data[0..-2].toString()
sql.executeInsert """
INSERT INTO Iris (class, features) VALUES (${data[-1]}, $features)
"""
}
printf "%-20s%-20s%-20s%n", 'Actual', 'Predicted', 'Confidence'
test.each { row ->
var data = row.split(',')
var features = VECTOR.ofFloat64Values(data[0..-2]*.toDouble() as double[])
var closest10 = sql.rows """
select class from Iris
order by vector_distance(features, $features, EUCLIDEAN)
fetch first 10 rows only
"""
var results = closest10
.groupBy{ e -> e.CLASS }
.collectEntries { e -> [e.key, e.value.size()]}
var predicted = results.max{ e -> e.value }
printf "%-20s%-20s%5d%n", data[-1], predicted.key, predicted.value * 10
}
}
这段代码有一些有趣的方面。
-
当我们插入数据时,我们只使用了字符串。由于 `features` 列的类型已知,因此它会自动进行转换。
-
或者,我们可以显式地处理类型,如查询中使用 `VECTOR.ofFloat64Values` 所示。
-
可能看起来奇怪的是,实际上并没有像传统算法那样进行模型训练。相反,SQL 查询中的 `vector_distance` 函数调用基于 kNN 的搜索来查找结果。在我们的例子中,我们请求了前 10 个最接近的点。
-
我们在查询中使用了 `EUCLIDEAN` 距离度量,但如果我们选择了 `EUCLIDEAN_SQUARED`,我们会得到类似的结果,但执行速度更快。直观地说,如果两个点彼此靠近,则这两个度量都将很小,而如果两个点不相关,则这两个度量都将很大。如果我们的特征特性是标准化的,我们预计会得到相同的结果。
-
`COSINE` 距离度量也工作得很好。直观地说,如果重要的不是萼片和花瓣的实际大小,而是它们的比率,那么相似的花卉将在我们的二维图上位于相同的角度,而这正是 `COSINE` 度量的对象。对于这个数据集,两者都很重要,但两种度量都能得到所有(或几乎所有)正确的答案。
-
一旦我们获得了前 10 个最接近的点,类预测就简单地从 10 个结果中得到预测次数最多的类。我们的置信度表示前 10 个结果中有多少与预测一致。
输出看起来像这样
Actual Predicted Confidence Iris-virginica Iris-virginica 90 Iris-virginica Iris-virginica 90 Iris-virginica Iris-virginica 100 Iris-virginica Iris-virginica 100 Iris-virginica Iris-versicolor 60 Iris-setosa Iris-setosa 100 Iris-setosa Iris-setosa 100 Iris-setosa Iris-setosa 100 Iris-setosa Iris-setosa 100 Iris-setosa Iris-setosa 100 Iris-virginica Iris-virginica 100 Iris-versicolor Iris-versicolor 100 Iris-versicolor Iris-versicolor 100 Iris-versicolor Iris-versicolor 70 Iris-virginica Iris-virginica 100 Iris-virginica Iris-virginica 100 Iris-setosa Iris-setosa 100 Iris-versicolor Iris-versicolor 100 Iris-virginica Iris-virginica 100 Iris-versicolor Iris-versicolor 100 Iris-setosa Iris-setosa 100 Iris-setosa Iris-setosa 100 Iris-versicolor Iris-versicolor 100 Iris-virginica Iris-virginica 90 Iris-setosa Iris-setosa 100 Iris-virginica Iris-virginica 90 Iris-setosa Iris-setosa 100 Iris-setosa Iris-setosa 100 Iris-virginica Iris-virginica 100 Iris-virginica Iris-virginica 100
只有一个结果是错误的(上面第一个 **粗体** 行)。由于我们对数据进行了随机洗牌,因此在其他运行中,我们可能会得到不同数量的错误结果。
我们可以通过在三维图中绘制 10 个最接近的点来可视化距离查询是如何工作的。我们将对 70% 置信度情况下返回的点(上面第二个 **粗体** 行)执行此操作。
这是一个主成分分析 (PCA) 图,它将我们的 4 个维度(花瓣宽度和长度、萼片宽度和长度)投影到 3 个维度上。
红色的大圆点是我们测试查询特征的投影。小圆点是数据集中未选择的点。中等圆点是 `vector_distance` 查询返回的点。返回了 7 个 Versicolor 点(蓝色)和 3 个 Virginica 点(橙色)。我们知道该数据点的结果是 Versicolor。
结论
我们快速地了解了如何使用来自 Oracle 23ai 的向量数据类型和 Apache Groovy。