暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

flink model serving 增加支持pytorch功能

IT技术小输出 2021-05-09
2539
很久没有更新了,这段时间实在忙,终于把事情告一段落(*^▽^*)

现在对算法模型的训练和推理都趋向于实时训练/推理,效率更好,吞吐量更大,时延更低;而且可以在线的更新模型,可以在一套系统中,使用不同的模型进行实时在线推理。flink作为现在实时数据处理的代表框架,提供了一套丐版的model serving来实现在线推理/训练和实时更新,并且和kafka集成实现增量在线训练。

但是问题在于,flink开源的FLIP-23,仅仅提供了一套支持TensorFlow、pmml格式模型的核心代码;因为FLIP-23的源码使用Java开发,模型训练,基本上都是通过Python训练,要兼容其他框架训练的模型,就必须解决Java代码加载Python模型的问题。

当我想对pytorch模型提供model serving来进行实时推理和模型更新,首先考虑就是解决在Java中加载pytorch的问题。pytorch模型是无法转换为pmml格式的,在进行调研之后,找到了亚马逊的开源框架DJL,它实现了在Java中调用加载pytorch模型、Keras模型等。

DJL是一个很新的开源框架,20年5月份提供支持pytorch的功能,所以在使用中会踩一些坑,很难去找到解决的方法,我在使用中遇到了一些问题,包括引入的依赖平台和自己电脑的cuda不兼容,操作系统不兼容(在win7上使用真的很费劲),等等很多问题,有一个bug甚至翻墙都没有找到解决的方法;DJL的使用需要你对算法模型的一些参数了解的比较清楚,我估计有可能是我对算法模型的一些参数并不了解,导致出现的问题。下面贴上DJL的介绍:
Deep Java Library (DJL) is an open-source, high-level, engine-agnostic Java framework for deep learning. DJL is designed to be easy to get started with and simple to use for Java developers. DJL provides a native Java development experience and functions like any other regular Java library.


You don't have to be machine learning/deep learning expert to get started. You can use your existing Java expertise as an on-ramp to learn and use machine learning and deep learning. You can use your favorite IDE to build, train, and deploy your models. DJL makes it easy to integrate these models with your Java applications.


Because DJL is deep learning engine agnostic, you don't have to make a choice between engines when creating your projects. You can switch engines at any point. To ensure the best performance, DJL also provides automatic CPU/GPU choice based on hardware configuration.


DJL's ergonomic API interface is designed to guide you with best practices to accomplish deep learning tasks. The following pseudocode demonstrates running inference:


// Assume user uses a pre-trained model from model zoo, they just need to load it
Criteria<Image, Classifications> criteria =
Criteria.builder()
.optApplication(Application.CV.OBJECT_DETECTION) // find object dection model
.setTypes(Image.class, Classifications.class) // define input and output
.optFilter("backbone", "resnet50") // choose network architecture
.build();


try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria)) {
try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
Image img = ImageFactory.getInstance().fromUrl("http://..."); // read image
Classifications result = predictor.predict(img);


// get the classification and probability
...
}
}
The following pseudocode demonstrates running training:


// Construct your neural network with built-in blocks
Block block = new Mlp(28, 28);


try (Model model = Model.newInstance("mlp")) { // Create an empty model
model.setBlock(block); // set neural network to model


// Get training and validation dataset (MNIST dataset)
Dataset trainingSet = new Mnist.Builder().setUsage(Usage.TRAIN) ... .build();
Dataset validateSet = new Mnist.Builder().setUsage(Usage.TEST) ... .build();


// Setup training configurations, such as Initializer, Optimizer, Loss ...
TrainingConfig config = setupTrainingConfig();
try (Trainer trainer = model.newTrainer(config)) {
/*
* Configure input shape based on dataset to initialize the trainer.
* 1st axis is batch axis, we can use 1 for initialization.
* MNIST is 28x28 grayscale image and pre processed into 28 * 28 NDArray.
*/
Shape inputShape = new Shape(1, 28 * 28);
trainer.initialize(new Shape[] {inputShape});


EasyTrain.fit(trainer, epoch, trainingSet, validateSet);
}


// Save the model
model.save(modelDir, "mlp");
}

在GitHub上的地址:https://github.com/deepjavalibrary/djl

因为使用DJL出现的一些BUG没有妥善的解决,我又去寻找一些其他的方法。之后发现可以将pytorch模型转换为onnx的格式:
Open Neural Network Exchange (ONNX) is an open ecosystem that empowers AI developers
to choose the right tools as their project evolves. ONNX provides an open source format
for AI models, both deep learning and traditional ML. It defines an extensible computation
graph model, as well as definitions of built-in operators and standard data types. Currently
we focus on the capabilities needed for inferencing (scoring).


ONNX is widely supported and can be found in many frameworks,
tools, and hardware. Enabling interoperability between different
frameworks and streamlining the path from research to production
helps increase the speed of innovation in the AI community.
We invite the community to join us and further evolve ONNX.

在将pytorch模型转换为onnx之后,就可以在Java中直接调用模型进行推理,但是亲测在win7系统会报错,win10可以正常运行。

pytorch模型要转成onnx,一般使用Python就可以直接转了:


model_inits = {
'rf_lw50_voc' : rf_lw50,
}
n_classes=2
models = dict()
for key,fun in six.iteritems(model_inits):
# 这里调用的其实是def rf_lw50(num_classes, imagenet=False, pretrained=True, **kwargs)
net = fun(n_classes, pretrained=True).eval()
if has_cuda:
net = net.cuda()
models[key] = net
model=net
#给定一个确定的输入 实际上可以随机初始化一个和需要的尺寸一致的numpy
img_path="/home/ahhh/PycharmProjects/light-weight-refinenet/examples/imgs/blind/1110a.png"
img = np.array(Image.open(img_path))
input = (torch.Tensor(prepare_img(img).transpose(2, 0, 1)[None])).float()
if has_cuda:
input = input.cuda()
# onnx模型文件输出
torch_out = torch.onnx._export(model, input, "lw50.onnx",export_params=True)

上面的脚本基本上就是将一个pytorch模型转换为onnx格式的大概步骤。

java调用onnx 可以看这里:
https://github.com/microsoft/onnxruntime-openenclave/blob/openenclave-public/java/src/test/java/ai/onnxruntime/InferenceTest.java#L66

解决了Java调用pytorch模型的问题,剩下的就是对flink model serving开源代码的开发了,对核心类和接口进行修改或增加,添加支持onnx格式模型的功能。我二次开发后的源码就不贴上来了(*^▽^*)。
文章转载自IT技术小输出,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

评论