使用 Oracle 23ai 向量数据类型与 Groovy 对鸢尾花进行分类

作者:Paul King
发布日期:2024-06-30 11:21PM


鸢尾花 一个经典的数据科学 数据集 收集了鸢尾花的特征。它记录了三种物种SetosaVersicolorVirginica)的萼片花瓣宽度长度

Iris 项目 位于 groovy-data-science 仓库 中,专门用于此示例。它包含多个 Groovy 脚本和一个 Jupyter/BeakerX 笔记本,重点介绍了这个示例,比较和对比了各种库和各种分类算法。

之前的一篇 博客文章 使用多个深度学习库描述了此示例,并给出了利用 GraalVM 的解决方案。在本篇博客文章中,我们将研究如何使用 Oracle 23ai 的向量数据类型和向量 AI 查询对我们数据集的一部分进行分类。

一般来说,许多机器学习/AI 算法处理信息向量。这些信息可能是实际数据值,比如我们花卉的特征,也可能是数据值的投影,或者文本、视频、图像或音频文件的关键信息的表示。后者通常被称为嵌入。

对我们而言,我们将找到具有相似特征的花卉。在其他相似性搜索场景中,我们可以检测欺诈性交易、找到客户推荐,或者根据图像嵌入的“接近度”找到相似的图像。

数据集

前面提到的 Iris 项目 展示了如何使用各种技术对鸢尾花数据集进行分类。特别是,一个例子使用了 Smile 库的 kNN 分类算法。该示例使用整个数据集来训练模型,然后在整个数据集上运行模型以评估其准确性。如分类与花瓣大小的图表所示,该算法在 Virginica 和 Versicolor 分组重叠处的数据点上存在一些问题。

Graph of predicted vs actual Iris flower classifications

如果我们查看分类与萼片大小的关系,我们可以看到更多的混淆可能性。

Graph of predicted vs actual Iris flower classifications

紫色和绿色点表示分类错误的花卉。

相应的混淆矩阵也显示了这些结果。

Confusion matrix:
ROW=truth and COL=predicted
class  0 |      50 |       0 |       0 |
class  1 |       0 |      47 |       3 |
class  2 |       0 |       3 |      47 |

一般来说,在原始数据集上运行模型可能并不理想,因为我们无法获得准确的错误计算结果,但这确实突出了我们数据的一些重要信息。在我们的例子中,我们可以看到 Virginica 和 Versacolor 类变得拥挤,这两个组重叠处的数据点可能会出现错误分类的情况。

数据库解决方案

我们的数据存储在 CSV 文件中。

iris CSV file

它恰好包含 50 个鸢尾花的三个类。首先,我们从 CSV 文件中加载数据集,跳过标题行并对剩余的行进行洗牌,以确保我们测试的是三个鸢尾花类的随机混合。

var file = getClass().classLoader.getResource('iris_data.csv').file as File
var rows = file.readLines()[1..-1].shuffled() // skip header and shuffle
var (training, test) = rows.chop(rows.size() * .8 as int, -1)

洗牌后,我们将数据分成两个集合。前 80% 将进入数据库。它对应于数据科学术语中的“训练”数据。最后 20% 将对应于我们的“测试”数据。

接下来,我们定义 SQL 连接所需的 정보。

var url = 'jdbc:oracle:thin:@localhost:1521/FREEPDB1'
var user = 'some_user'
var password = 'some_password'
var driver = 'oracle.jdbc.driver.OracleDriver'

接下来,我们创建数据库连接,并使用它来插入“训练”行,然后针对“测试”行进行测试。

Sql.withInstance(url, user, password, driver) { sql ->
    training.each { row ->
        var data = row.split(',')
        var features = data[0..-2].toString()
        sql.executeInsert """
            INSERT INTO Iris (class, features) VALUES (${data[-1]}, $features)
        """
    }
    printf "%-20s%-20s%-20s%n", 'Actual', 'Predicted', 'Confidence'
    test.each { row ->
        var data = row.split(',')
        var features = VECTOR.ofFloat64Values(data[0..-2]*.toDouble() as double[])
        var closest10 = sql.rows """
        select class from Iris
        order by vector_distance(features, $features, EUCLIDEAN)
        fetch first 10 rows only
        """
        var results = closest10
                .groupBy{ e -> e.CLASS }
                .collectEntries { e -> [e.key, e.value.size()]}
        var predicted = results.max{ e -> e.value }
        printf "%-20s%-20s%5d%n", data[-1], predicted.key, predicted.value * 10
    }
}

这段代码有一些有趣的方面。

  • 当我们插入数据时,我们只使用了字符串。由于 `features` 列的类型已知,因此它会自动进行转换。

  • 或者,我们可以显式地处理类型,如查询中使用 `VECTOR.ofFloat64Values` 所示。

  • 可能看起来奇怪的是,实际上并没有像传统算法那样进行模型训练。相反,SQL 查询中的 `vector_distance` 函数调用基于 kNN 的搜索来查找结果。在我们的例子中,我们请求了前 10 个最接近的点。

  • 我们在查询中使用了 `EUCLIDEAN` 距离度量,但如果我们选择了 `EUCLIDEAN_SQUARED`,我们会得到类似的结果,但执行速度更快。直观地说,如果两个点彼此靠近,则这两个度量都将很小,而如果两个点不相关,则这两个度量都将很大。如果我们的特征特性是标准化的,我们预计会得到相同的结果。

  • `COSINE` 距离度量也工作得很好。直观地说,如果重要的不是萼片和花瓣的实际大小,而是它们的比率,那么相似的花卉将在我们的二维图上位于相同的角度,而这正是 `COSINE` 度量的对象。对于这个数据集,两者都很重要,但两种度量都能得到所有(或几乎所有)正确的答案。

  • 一旦我们获得了前 10 个最接近的点,类预测就简单地从 10 个结果中得到预测次数最多的类。我们的置信度表示前 10 个结果中有多少与预测一致。

输出看起来像这样

Actual              Predicted           Confidence
Iris-virginica      Iris-virginica         90
Iris-virginica      Iris-virginica         90
Iris-virginica      Iris-virginica        100
Iris-virginica      Iris-virginica        100
Iris-virginica      Iris-versicolor        60
Iris-setosa         Iris-setosa           100
Iris-setosa         Iris-setosa           100
Iris-setosa         Iris-setosa           100
Iris-setosa         Iris-setosa           100
Iris-setosa         Iris-setosa           100
Iris-virginica      Iris-virginica        100
Iris-versicolor     Iris-versicolor       100
Iris-versicolor     Iris-versicolor       100
Iris-versicolor     Iris-versicolor        70
Iris-virginica      Iris-virginica        100
Iris-virginica      Iris-virginica        100
Iris-setosa         Iris-setosa           100
Iris-versicolor     Iris-versicolor       100
Iris-virginica      Iris-virginica        100
Iris-versicolor     Iris-versicolor       100
Iris-setosa         Iris-setosa           100
Iris-setosa         Iris-setosa           100
Iris-versicolor     Iris-versicolor       100
Iris-virginica      Iris-virginica         90
Iris-setosa         Iris-setosa           100
Iris-virginica      Iris-virginica         90
Iris-setosa         Iris-setosa           100
Iris-setosa         Iris-setosa           100
Iris-virginica      Iris-virginica        100
Iris-virginica      Iris-virginica        100

只有一个结果是错误的(上面第一个 **粗体** 行)。由于我们对数据进行了随机洗牌,因此在其他运行中,我们可能会得到不同数量的错误结果。

我们可以通过在三维图中绘制 10 个最接近的点来可视化距离查询是如何工作的。我们将对 70% 置信度情况下返回的点(上面第二个 **粗体** 行)执行此操作。

closest 10 points

这是一个主成分分析 (PCA) 图,它将我们的 4 个维度(花瓣宽度和长度、萼片宽度和长度)投影到 3 个维度上。

红色的大圆点是我们测试查询特征的投影。小圆点是数据集中未选择的点。中等圆点是 `vector_distance` 查询返回的点。返回了 7 个 Versicolor 点(蓝色)和 3 个 Virginica 点(橙色)。我们知道该数据点的结果是 Versicolor。

结论

我们快速地了解了如何使用来自 Oracle 23ai 的向量数据类型和 Apache Groovy。