使用 Groovy™、Deep Java Library (DJL) 和 Apache MXNet 检测对象
发布时间:2022-08-01 11:52AM
这篇博文探讨了如何将 Apache Groovy 与 Deep Java Library (DJL) 结合使用,并以 Apache MXNet 引擎为后端来检测图像中的对象。(Apache MXNet 是 ASF 的一个孵化项目。)
深度学习
Deep Java Library (DJL) 和 Apache MXNet
与其编写自己的神经网络,不如使用像 DJL 这样的库,它们提供了高级抽象,可以在某种程度上自动化创建必要的神经网络层。DJL 是引擎无关的,因此它能够支持不同的后端,包括 Apache MXNet、PyTorch、TensorFlow 和 ONNX Runtime。我们将使用默认引擎,对于我们的应用程序(在撰写本文时)来说,它是 Apache MXNet。
Apache MXNet 提供了底层引擎。它支持命令式和符号执行,使用多 GPU 或多主机硬件对模型进行分布式训练,以及多种语言绑定。Groovy 完全兼容 Java 绑定。
在 Groovy 中使用 DJL
Groovy 使用 Java 绑定。可以考虑查看 DJL 的 Java 初学者教程——它们几乎可以直接用于 Groovy。
对于我们的例子,我们需要做的第一件事是下载我们想要运行对象检测模型的图像
Path tempDir = Files.createTempDirectory("resnetssd")
def imageName = 'dog-ssd.jpg'
Path localImage = tempDir.resolve(imageName)
def url = new URL("https://s3.amazonaws.com/model-server/inputs/$imageName")
DownloadUtils.download(url, localImage, new ProgressBar())
Image img = ImageFactory.instance.fromFile(localImage)
它恰好是一个众所周知的现有图像。我们将在一个临时目录中存储该图像的本地副本,并且我们将使用 DJL 附带的实用程序类,在图像下载时提供一个漂亮的进度条。DJL 提供了自己的图像类,因此我们将使用下载图像中的相应类创建一个实例。
接下来我们要配置我们的神经网络层
def criteria = Criteria.builder()
.optApplication(Application.CV.OBJECT_DETECTION)
.setTypes(Image, DetectedObjects)
.optFilter("backbone", "resnet50")
.optEngine(Engine.defaultEngineName)
.optProgress(new ProgressBar())
.build()
DLJ 支持多种模型“应用程序”,包括图像分类、文字识别、情感分析、线性回归等。我们将选择“对象检测”。这种应用程序会寻找图像中已知对象的边界框。“types”配置选项标识我们的输入将是图像,输出将是检测到的对象。“filter”选项表示我们将使用 ResNet-50(一个 50 层深的卷积神经网络,常作为许多计算机视觉任务的骨干)。我们将“engine”设置为默认引擎,即 Apache MXNet。我们还配置了一个可选的进度条,以在模型运行时提供进度反馈。
现在我们已经整理好了配置,我们将用它来加载模型,然后使用模型进行对象预测
def detection = criteria.loadModel().withCloseable { model ->
model.newPredictor().predict(img)
}
detection.items().each { println it }
img.drawBoundingBoxes(detection)
为了更好地衡量,我们将在图像中绘制边界框。
接下来,我们将图像保存到文件中,并使用 Groovy 的 SwingBuilder 显示它。
Path imageSaved = tempDir.resolve('detected.png')
imageSaved.withOutputStream { os -> img.save(os, 'png') }
def saved = ImageIO.read(imageSaved.toFile())
new SwingBuilder().edt {
frame(title: "$detection.numberOfObjects detected objects",
size: [saved.width, saved.height],
defaultCloseOperation: DISPOSE_ON_CLOSE,
show: true) { label(icon: imageIcon(image: saved)) }
}
构建并运行我们的应用程序
我们的代码存储在一个名为 ObjectDetect.groovy
的源文件中。
我们使用 Gradle 作为我们的构建文件
apply plugin: 'groovy'
apply plugin: 'application'
repositories {
mavenCentral()
}
application {
mainClass = 'ObjectDetect'
}
dependencies {
implementation "ai.djl:api:0.18.0"
implementation "org.apache.groovy:groovy:4.0.4"
implementation "org.apache.groovy:groovy-swing:4.0.4"
runtimeOnly "ai.djl:model-zoo:0.18.0"
runtimeOnly "ai.djl.mxnet:mxnet-engine:0.18.0"
runtimeOnly "ai.djl.mxnet:mxnet-model-zoo:0.18.0"
runtimeOnly "ai.djl.mxnet:mxnet-native-auto:1.8.0"
runtimeOnly "org.apache.groovy:groovy-nio:4.0.4"
runtimeOnly "org.slf4j:slf4j-jdk14:1.7.36"
}
我们使用 gradle run
任务运行应用程序
paulk@pop-os:/extra/projects/groovy-data-science$ ./gradlew DLMXNet:run > Task :DeepLearningMxnet:run Downloading: 100% |████████████████████████████████████████| dog-ssd.jpg Loading: 100% |████████████████████████████████████████| ... class: "car", probability: 0.99991, bounds: [x=0.611, y=0.137, width=0.293, height=0.160] class: "bicycle", probability: 0.95385, bounds: [x=0.162, y=0.207, width=0.594, height=0.588] class: "dog", probability: 0.93752, bounds: [x=0.168, y=0.350, width=0.274, height=0.593]
显示的图像如下
更多信息
结论
我们已经研究了如何使用 Apache Groovy、DLJ 和 Apache MXNet 来检测图像中的对象。我们使用了基于丰富的深度学习模型的模型,但我们不需要深入了解模型或其神经网络层的细节。DLJ 和 Apache MXNet 为我们完成了繁重的工作。Groovy 为构建我们的应用程序提供了简单的编码体验。