This article has two parts. The first part trains a playing-card recognition model with MobileNetV2-SSD-Lite; the second part deploys the trained model to a phone. The results look like this:
PC demo:
(embedded demo video: 20220316)
Bilibili: video link
Mobile demo:
(embedded demo video: playing-card recognition)
App download: download link
The dataset is shared here:
https://cloud.189.cn/t/eYRzu27ne6za (access code: upa5)
Building this dataset took quite a bit of work; download it from the link above if you need it.
1 Network training
Different detection networks could be used; this article uses MobileNetV2-SSD-Lite. See the corresponding GitHub repository for the training details.
After training, convert the model to ONNX format.
ONNX model download:
Link: https://pan.baidu.com/s/1kchAu_0B3KfJ3WkYY9IDXQ
Extraction code: g9fq
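If you would rather export the ONNX file yourself than download it, the conversion boils down to a torch.onnx.export call. The sketch below is only illustrative: create_model and the checkpoint filename are placeholders for whatever constructor and weights your training repository provides; the output names 'scores' and 'boxes' are the ones the OpenCV code below expects.

import torch

# Hypothetical constructor: substitute the MobileNetV2-SSD-Lite builder from the training repo
model = create_model(num_classes=55)          # 54 cards + background
model.load_state_dict(torch.load('mb2-ssd-lite.pth', map_location='cpu'))
model.eval()

# Export with a fixed 300x300 input, matching the inference code in this article
dummy = torch.randn(1, 3, 300, 300)
torch.onnx.export(model, dummy, 'mb2-ssd-lite.onnx',
                  input_names=['input'],
                  output_names=['scores', 'boxes'],
                  opset_version=11)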
Use OpenCV to load the ONNX model and run forward inference. The code is shown below:
import cv2
import numpy as np

# The 54 class names, in the same order as the training labels
# (the Java classNames array in part 2 holds the human-readable card names)
label_dic = ['card1', 'card2', 'card3', 'card4', 'card5',
             'card6', 'card7', 'card8', 'card9', 'card10',
             'card11', 'card12', 'card13',
             'second1', 'second2', 'second3', 'second4', 'second5',
             'second6', 'second7', 'second8', 'second9', 'second10',
             'second11', 'second12', 'second13',
             'three1', 'three2', 'three3', 'three4', 'three5', 'three6',
             'three7', 'three8', 'three9', 'three10', 'three11', 'three12',
             'three13',
             'four1', 'four2', 'four3', 'four4', 'four5', 'four6', 'four7',
             'four8', 'four9', 'four10', 'four11', 'four12', 'four13',
             'xiao', 'da']

net = cv2.dnn.readNetFromONNX('./models/mb2-ssd-lite.onnx')


def detector(ori_image):
    h, w = ori_image.shape[:2]
    # Preprocess: resize to 300x300, map pixels to roughly [-1, 1], convert BGR -> RGB
    blob = cv2.dnn.blobFromImage(ori_image,
                                 scalefactor=1.0 / 128,
                                 size=(300, 300),
                                 mean=[127, 127, 127],
                                 swapRB=True,
                                 crop=False)
    # Run the model; the ONNX file exposes two outputs named 'scores' and 'boxes'
    net.setInput(blob)
    out = net.forward(['scores', 'boxes'])
    # scores => (1, 3000, 55), boxes => (1, 3000, 4) with normalized (x1, y1, x2, y2)
    scores = out[0][0][..., 1:]   # drop the background column -> (3000, 54)
    boxes = out[1][0]
    # Best score and class index for each of the 3000 priors
    max_score = np.max(scores, axis=-1)
    max_index = np.argmax(scores, axis=-1)
    # NMSBoxes expects boxes as (x, y, w, h), so convert from corner format first
    boxes_xywh = boxes.copy()
    boxes_xywh[:, 2:] = boxes[:, 2:] - boxes[:, :2]
    box_id = cv2.dnn.NMSBoxes(boxes_xywh.tolist(), max_score.tolist(), 0.4, 0.45)
    if len(box_id) == 0:
        return ori_image
    box_id = np.array(box_id).flatten()
    boxes = boxes[box_id]
    confidences = max_score[box_id]
    label = max_index[box_id]
    for box, confidence, la in zip(boxes, confidences, label):
        # Scale the normalized coordinates back to the original image size
        box = box * np.array([w, h, w, h])
        box = box.astype(np.int32)
        cv2.rectangle(ori_image, (box[0], box[1]), (box[2], box[3]), (255, 255, 0), 2)
        cv2.putText(ori_image, f'{label_dic[la]}:{round(float(confidence), 3)}',
                    (box[0] + 20, box[1] + 40),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1,  # font scale
                    (255, 0, 255),
                    2)
    return ori_image


cap = cv2.VideoCapture(r'E:\card.mp4')
while True:
    r, ori_image = cap.read()
    if not r:
        break
    detector(ori_image)
    cv2.imshow('result', ori_image)
    cv2.waitKey(1)
cap.release()
cv2.destroyAllWindows()
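The forward call addresses the outputs by name ('scores' and 'boxes'), so those names have to match what the ONNX file actually exposes. If in doubt, a quick check with the same OpenCV API is:

import cv2

net = cv2.dnn.readNetFromONNX('./models/mb2-ssd-lite.onnx')
# Prints the names of the unconnected output layers, e.g. ('scores', 'boxes')
print(net.getUnconnectedOutLayersNames())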
2 Mobile deployment
On the mobile side, OpenCV (the OpenCV Android SDK) is again used to load the ONNX model and run detection on camera frames.
Main activity code (MyCamera.java):
package com.myapp.puke;
import androidx.appcompat.app.AppCompatActivity;
import android.Manifest;
import android.content.Context;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.media.MediaPlayer;
import android.os.Build;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.view.Window;
import android.view.WindowManager;
import android.widget.Toast;
import org.opencv.android.BaseLoaderCallback;
import org.opencv.android.CameraActivity;
import org.opencv.android.CameraBridgeViewBase;
import org.opencv.android.JavaCameraView;
import org.opencv.android.LoaderCallbackInterface;
import org.opencv.android.OpenCVLoader;
import org.opencv.android.Utils;
import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.core.MatOfFloat;
import org.opencv.core.MatOfInt;
import org.opencv.core.MatOfRect2d;
import org.opencv.core.Point;
import org.opencv.core.Rect2d;
import org.opencv.core.Scalar;
import org.opencv.core.Size;
import org.opencv.dnn.Dnn;
import org.opencv.dnn.Net;
import org.opencv.imgproc.Imgproc;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class MyCamera extends CameraActivity implements CameraBridgeViewBase.CvCameraViewListener2 {
static {
System.loadLibrary("native-lib");
}
private JavaCameraView mOpenCvCameraView;
private int M_REQUEST_CODE = 203;
private String[] permissions = {Manifest.permission.CAMERA};
// Initialize OpenCV manager.
private BaseLoaderCallback mLoaderCallback = new BaseLoaderCallback(this) {
@Override
public void onManagerConnected(int status) {
switch (status) {
case LoaderCallbackInterface.SUCCESS: {
mOpenCvCameraView.enableView();
break;
}
default:
break;
}
}
};
private MediaPlayer mediaPlayer;
private Mat src;
private Net net;
    final double IN_SCALE_FACTOR = 0.0078125;   // 1/128, same scaling as on the PC side
    final double MEAN_VAL = 127.0;              // mean subtracted from each channel
    final double THRESHOLD = 0.2;               // (unused; the score threshold of 0.4 is hardcoded below)
    final int IN_WIDTH = 300;                   // SSD input width
    final int IN_HEIGHT = 300;                  // SSD input height
private Bitmap bp;
private long time1=0;
private long time2=0;
    @Override
    public void onResume() {
        super.onResume();
        // Initialize OpenCV from within the app package and enable the camera view on success
        if (OpenCVLoader.initDebug()) {
            mLoaderCallback.onManagerConnected(LoaderCallbackInterface.SUCCESS);
        }
    }
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_my_camera);
        // Transparent status bar
if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.LOLLIPOP) {
Window window = getWindow();
window.clearFlags(WindowManager.LayoutParams.FLAG_TRANSLUCENT_STATUS
| WindowManager.LayoutParams.FLAG_TRANSLUCENT_NAVIGATION);
window.getDecorView().setSystemUiVisibility(View.SYSTEM_UI_FLAG_LAYOUT_FULLSCREEN
| View.SYSTEM_UI_FLAG_LAYOUT_HIDE_NAVIGATION
| View.SYSTEM_UI_FLAG_LAYOUT_STABLE);
window.addFlags(WindowManager.LayoutParams.FLAG_DRAWS_SYSTEM_BAR_BACKGROUNDS);
window.setStatusBarColor(Color.TRANSPARENT);
window.setNavigationBarColor(Color.TRANSPARENT);
}
// Set up camera listener.
mOpenCvCameraView = (JavaCameraView)findViewById(R.id.CameraView);
mOpenCvCameraView.setVisibility(CameraBridgeViewBase.VISIBLE);
mOpenCvCameraView.setCvCameraViewListener(this);
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
requestPermissions(permissions, M_REQUEST_CODE);
}
}
    public void onCameraViewStarted(int width, int height) {
        // Load the ONNX model copied out of the assets folder
        String modelPath = getPath("mb2-ssd-lite.onnx", this);
        net = Dnn.readNetFromONNX(modelPath);
        bp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
    }
    // Classes whose sound has already been played in the current 5-second window
    private List<Integer> music_index = new ArrayList<>();
public Mat onCameraFrame(CameraBridgeViewBase.CvCameraViewFrame inputFrame) {
        // Object detection
        // Reset the list of played sounds every 5 seconds
        time2 = System.currentTimeMillis();
        if (time2 - time1 > 5000) {
            time1 = time2;
            music_index = new ArrayList<>();
        }
        // Detect on the current frame and draw the results back into src
        src = inputFrame.rgba();
        Imgproc.cvtColor(src, src, Imgproc.COLOR_RGBA2RGB);
        Utils.matToBitmap(src, bp);
        bp = detection(bp);
        Utils.bitmapToMat(bp, src);
        // Collect the classes that were detected at least twice
        List<Integer> index_class = new ArrayList<>();
        for (int x = 0; x < number_class.length; x++) {
            if (number_class[x] >= 2) {
                index_class.add(x);
            }
        }
        if (index_class.size() > 0) {
            for (int j : index_class) {
                // Play each class's sound at most once per 5-second window
                if (!music_index.contains(j)) {
                    music_index.add(j);
                    mediaPlayer = MediaPlayer.create(getApplicationContext(), musices[j]);
                    mediaPlayer.start();
                    Log.i("aa", "" + music_index);
                }
            }
            // Reset the per-class counters
            number_class = new int[54];
        }
return src;
}
public void onCameraViewStopped() {
// src.release();
}
@Override
public void onPause() {
super.onPause();
if (mOpenCvCameraView != null)
mOpenCvCameraView.disableView();
}
    @Override
    public void onDestroy() {
        super.onDestroy();
        if (mOpenCvCameraView != null)
            mOpenCvCameraView.disableView();
    }
@Override
protected List<? extends CameraBridgeViewBase> getCameraViewList() {
return Collections.singletonList(mOpenCvCameraView);
}
private static String getPath(String file, Context context) {
AssetManager assetManager = context.getAssets();
BufferedInputStream inputStream = null;
try {
// Read data from assets.
inputStream = new BufferedInputStream(assetManager.open(file));
byte[] data = new byte[inputStream.available()];
inputStream.read(data);
inputStream.close();
// Create copy file in storage.
File outFile = new File(context.getFilesDir(), file);
FileOutputStream os = new FileOutputStream(outFile);
os.write(data);
os.close();
// Return a path to file which may be read in common way.
return outFile.getAbsolutePath();
        } catch (IOException ex) {
            Log.e("aa", "Failed to copy the model file from assets", ex);
        }
        return "";
}
    // The 54 class names shown on screen (黑桃 = Spades, 红桃 = Hearts, 方块 = Diamonds, 梅花 = Clubs, 小王/大王 = Jokers)
private static final String[] classNames = {"黑桃A",
"黑桃2", "黑桃3", "黑桃4", "黑桃5","黑桃6", "黑桃7", "黑桃8", "黑桃9","黑桃10", "黑桃J", "黑桃Q", "黑桃K",
"红桃A", "红桃2","红桃3","红桃4","红桃5","红桃6","红桃7","红桃8","红桃9","红桃10","红桃J","红桃Q","红桃K",
"方块A","方块2","方块3","方块4","方块5","方块6","方块7","方块8","方块9","方块10","方块J","方块Q","方块K",
"梅花A","梅花2","梅花3","梅花4","梅花5","梅花6","梅花7","梅花8","梅花9","梅花10","梅花J","梅花Q","梅花K",
"小王","大王"};
    // One audio clip per class, stored in res/raw (m1 ... m54)
    private static final int[] musices = {R.raw.m1,
R.raw.m2, R.raw.m3, R.raw.m4, R.raw.m5,R.raw.m6, R.raw.m7, R.raw.m8, R.raw.m9,R.raw.m10, R.raw.m11, R.raw.m12,R.raw.m13,
R.raw.m14, R.raw.m15, R.raw.m16, R.raw.m17,R.raw.m18, R.raw.m19, R.raw.m20, R.raw.m21,R.raw.m22, R.raw.m23, R.raw.m24,R.raw.m25,R.raw.m26,
R.raw.m27, R.raw.m28, R.raw.m29, R.raw.m30,R.raw.m31, R.raw.m32, R.raw.m33, R.raw.m34,R.raw.m35, R.raw.m36, R.raw.m37,R.raw.m38,R.raw.m39,
R.raw.m40, R.raw.m41, R.raw.m42, R.raw.m43,R.raw.m44, R.raw.m45, R.raw.m46, R.raw.m47,R.raw.m48, R.raw.m49, R.raw.m50,R.raw.m51,R.raw.m52,
R.raw.m53,R.raw.m54};
public native String stringFromJNI();
    // Per-class detection counts, used to decide which sounds to play
    private int[] number_class = new int[54];
    public Bitmap detection(Bitmap bp) {
        // Draw the results onto a mutable copy of the input bitmap
        Paint p = new Paint();
        android.graphics.Bitmap.Config bitmapConfig = bp.getConfig();
        bp = bp.copy(bitmapConfig, true);
        Canvas can = new Canvas(bp);
p.setAntiAlias(true);
        // Outline only (the default style is fill)
        p.setStyle(Paint.Style.STROKE);
        // Line width
        p.setStrokeWidth(5);
        // Box color
        p.setColor(0xFF33FFFF);
p.setTextAlign(Paint.Align.LEFT);
p.setTextSize(50);
        // src is already RGB (converted in onCameraFrame), so swapRB stays false
        Mat blob = Dnn.blobFromImage(src, IN_SCALE_FACTOR,
                new Size(IN_WIDTH, IN_HEIGHT),
                new Scalar(MEAN_VAL, MEAN_VAL, MEAN_VAL), false);
        net.setInput(blob);
        blob.release();
        // Output layer names ("scores" and "boxes")
        List<String> outnames = net.getUnconnectedOutLayersNames();
        // Forward pass
        List<Mat> detections = new ArrayList<Mat>();
        net.forward(detections, outnames);
        // Reshape the outputs: scores => (3000, 55), boxes => (3000, 4)
        Mat scores = detections.get(0);
        Mat boxes = detections.get(1);
        scores = scores.reshape(1, 3000).colRange(1, 55);   // drop the background column
        boxes = boxes.reshape(1, 3000);
        List<Rect2d> rect2dList = new ArrayList<>();    // box coordinates
        List<Float> confList = new ArrayList<>();       // confidences
        List<Integer> objIndexList = new ArrayList<>(); // class indices
        for (int i = 0; i < scores.rows(); i++) {
            Mat one_row = scores.rowRange(i, i + 1);
            Core.MinMaxLocResult max_index = Core.minMaxLoc(one_row);
            double max_value = max_index.maxVal;
            Point location = max_index.maxLoc;
            if (max_value > 0.4) {
                confList.add((float) max_value);
                objIndexList.add((int) location.x);
                // Normalized corner coordinates (x1, y1, x2, y2)
                Mat box_one = boxes.rowRange(i, i + 1);
                float[] aa = new float[4];
                box_one.get(0, 0, aa);
                double x1 = aa[0];
                double y1 = aa[1];
                double x2 = aa[2];
                double y2 = aa[3];
                // Rect2d is (x, y, width, height), so convert from corner format for NMS
                rect2dList.add(new Rect2d(x1, y1, x2 - x1, y2 - y1));
            }
        }
        // Non-maximum suppression
        // Indices of the boxes kept after NMS
        MatOfInt index = new MatOfInt();
        // Boxes as a MatOfRect2d
        MatOfRect2d boxe = new MatOfRect2d(rect2dList.toArray(new Rect2d[0]));
        // Confidences as a MatOfFloat
        float[] confArr = new float[confList.size()];
        for (int j = 0; j < confList.size(); j++) {
            confArr[j] = confList.get(j);
        }
        MatOfFloat con = new MatOfFloat(confArr);
        Dnn.NMSBoxes(boxe, con, 0.4f, 0.45f, index);
        if (index.empty()) {
            return bp;
        }
        // Draw the detections
        int[] ints = index.toArray();
        for (int x : ints) {
            double[] aa = new double[4];
            boxe.get(x, 0, aa);
            // Convert (x, y, w, h) back to pixel corner coordinates
            float x1 = (float) (aa[0] * src.width());
            float y1 = (float) (aa[1] * src.height());
            float x2 = (float) ((aa[0] + aa[2]) * src.width());
            float y2 = (float) ((aa[1] + aa[3]) * src.height());
            // Bounding box
            can.drawRect(x1, y1, x2, y2, p);
            // Filled background for the label
            p.setStyle(Paint.Style.FILL);
            p.setColor(0xFFFFCC00);
            can.drawRect(x1, y1 - 60, x2 + 150, y1, p);
            // Label text
            p.setColor(0xFFFF0000);
            can.drawText(classNames[objIndexList.get(x)] + ": " + String.format("%.3f", confList.get(x)), x1, y1 - 10, p);
            p.setStyle(Paint.Style.STROKE);
            p.setColor(0xFF33FFFF);
            // Count this class for the sound-playing logic in onCameraFrame
            number_class[objIndexList.get(x)] += 1;
        }
scores.release();
boxe.release();
index.release();
boxes.release();
con.release();
rect2dList.clear();
confList.clear();
objIndexList.clear();
return bp;
}
}
Layout code (activity_my_camera.xml):
<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:opencv="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:visibility="visible">
<org.opencv.android.JavaCameraView
android:id="@+id/CameraView"
android:layout_width="match_parent"
android:layout_height="match_parent"
opencv:camera_id="any"
opencv:show_fps="true" />
</RelativeLayout>