Skip to content

Commit

Permalink
merge pytorch-android.
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghao15536870732 committed Aug 3, 2020
1 parent d829b08 commit d4f6272
Show file tree
Hide file tree
Showing 24 changed files with 1,253 additions and 52 deletions.
20 changes: 12 additions & 8 deletions app/build.gradle
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
apply plugin: 'com.android.application'

android {
compileSdkVersion 27
compileSdkVersion 28
defaultConfig {
applicationId "com.example.ywang.diseaseidentification"
minSdkVersion 16
targetSdkVersion 27
minSdkVersion 21
targetSdkVersion 28
versionCode 1
versionName "1.0"
testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
Expand All @@ -24,25 +24,26 @@ android {

dependencies {
implementation fileTree(include: ['*.jar'], dir: 'libs')
implementation 'com.android.support:appcompat-v7:27.1.1'
//noinspection GradleCompatible
implementation 'com.android.support:appcompat-v7:28.0.0'
implementation 'com.android.support.constraint:constraint-layout:1.1.3'
testImplementation 'junit:junit:4.12'
androidTestImplementation 'com.android.support.test:runner:1.0.2'
androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
implementation 'com.android.support:support-v4:27.1.1'
implementation 'com.android.support:appcompat-v7:27.1.1'
implementation 'com.android.support:support-v4:28.0.0'
implementation 'com.android.support:appcompat-v7:28.0.0'
implementation 'com.github.forvv231:EasyNavigation:1.0.3'
implementation 'com.google.code.gson:gson:2.8.5'
//悬浮按钮
implementation 'com.getbase:floatingactionbutton:1.10.1'
implementation 'com.github.bumptech.glide:glide:4.5.0'
implementation 'com.android.support:design:27.1.1'
implementation 'com.android.support:design:28.0.0'
//侧滑栏
implementation 'com.mxn.soul:flowingdrawer-core:2.1.0'
implementation 'com.nineoldandroids:library:2.4.0'
//圆角图片
implementation 'de.hdodenhof:circleimageview:3.0.1'
implementation 'com.android.support:cardview-v7:27.1.1'
implementation 'com.android.support:cardview-v7:28.0.0'
//图片选择库
implementation 'com.github.LuckSiege.PictureSelector:picture_library:v2.2.3'
//多标签选择
Expand All @@ -63,4 +64,7 @@ dependencies {
implementation 'com.squareup.okhttp3:okhttp:3.4.1'
implementation 'com.google.code.gson:gson:2.7'
implementation 'org.jsoup:jsoup:1.9.2'

implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}
8 changes: 7 additions & 1 deletion app/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
<uses-permission android:name="android.permission.INTERNET" />
<!-- 允许程序设置内置sd卡的写权限 -->
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
<uses-permission android:name="android.permission.WRITE_SETTINGS"
tools:ignore="ProtectedPermissions" />
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<!-- 允许程序获取网络状态 -->
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
Expand Down Expand Up @@ -46,6 +48,8 @@
android:supportsRtl="true"
android:theme="@style/AppTheme"
tools:targetApi="n">

<uses-library android:name="org.apache.http.legacy" android:required="false" />
<activity
android:name=".view.activity.GuideActivity"
android:clipToPadding="true"
Expand Down Expand Up @@ -123,7 +127,9 @@
android:name=".view.activity.NearByActivity"
android:theme="@style/AppTheme"
android:windowSoftInputMode="stateHidden|stateUnchanged" />
<activity android:name=".view.activity.DiseaseDetailActivity"></activity>
<activity android:name=".view.activity.DiseaseDetailActivity"/>
<activity android:name=".view.activity.DetailActivity"/>
<activity android:name=".view.activity.MainResultActivity"/>
</application>

</manifest>
Binary file added app/src/main/assets/resnet50.pt
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.example.ywang.diseaseidentification.adapter;

import android.support.v7.widget.CardView;

public interface CardAdapter {
int MAX_ELEVATION_FACTOR = 10;
float getBaseElevation();
CardView getCardViewAt(int position);
int getCount();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package com.example.ywang.diseaseidentification.adapter;

import android.content.Context;
import android.content.Intent;
import android.support.annotation.NonNull;
import android.support.v4.view.PagerAdapter;
import android.support.v7.widget.CardView;
import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.Button;
import android.widget.TextView;

import com.example.ywang.diseaseidentification.R;
import com.example.ywang.diseaseidentification.bean.CardItem;
import com.example.ywang.diseaseidentification.view.activity.DetailActivity;

import java.util.ArrayList;
import java.util.List;

public class CardPagerAdapter extends PagerAdapter implements CardAdapter {

private Context mContext;
private List<CardView> mViews;
private List<CardItem> mData;
private float mBaseElevation;

public CardPagerAdapter(Context context) {
mContext = context;
mData = new ArrayList<>();
mViews = new ArrayList<>();
}

public void addCardItem(CardItem item) {
mViews.add(null);
mData.add(item);
}

public float getBaseElevation() {
return mBaseElevation;
}

@Override
public CardView getCardViewAt(int position) {
return mViews.get(position);
}

@Override
public int getCount() {
return mData.size();
}

@Override
public boolean isViewFromObject(@NonNull View view, @NonNull Object object) {
return view == object;
}

@NonNull
@Override
public Object instantiateItem(@NonNull ViewGroup container, int position) {
View view = LayoutInflater.from(container.getContext()).inflate(R.layout.card_item_main, container, false);
container.addView(view);
bind(mData.get(position), view);
CardView cardView = view.findViewById(R.id.cardView);

if (mBaseElevation == 0) {
mBaseElevation = cardView.getCardElevation();
}

cardView.setMaxCardElevation(mBaseElevation * MAX_ELEVATION_FACTOR);
mViews.set(position, cardView);
return view;
}

@Override
public void destroyItem(@NonNull ViewGroup container, int position, @NonNull Object object) {
container.removeView((View) object);
mViews.set(position, null);
}

private void bind(final CardItem item, View view) {
TextView titleTextView = view.findViewById(R.id.titleTextView);
TextView contentTextView = view.findViewById(R.id.contentTextView);
Button reTakeBtn = view.findViewById(R.id.re_take_btn);
Button moreBtn = view.findViewById(R.id.more_btn);
titleTextView.setText(item.getTitle());
contentTextView.setText(item.getText());
if(item.isScore_show()){
view.findViewById(R.id.score_label).setVisibility(View.VISIBLE);
reTakeBtn.setVisibility(View.GONE);
moreBtn.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View view) {
if (!item.getText().equals("")){
Intent intent = new Intent(mContext, DetailActivity.class);
intent.putExtra("link",item.getLink());
intent.putExtra("title",item.getTitle());
mContext.startActivity(intent);
}
}
});
}else {
view.findViewById(R.id.score_label).setVisibility(View.GONE);
moreBtn.setText("错误反馈");
reTakeBtn.setVisibility(View.VISIBLE);
reTakeBtn.setText("重新拍照");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.example.ywang.diseaseidentification.bean;

public class CardItem {
private String mTextResource;
private String mTitleResource;
private boolean score_show = true;
private boolean image_show = false;
private String link;

public CardItem(String title, String text,boolean isShow) {
mTitleResource = title;
mTextResource = text;
score_show = isShow;
}

public CardItem(String title, String text,boolean isShow,String href) {
mTitleResource = title;
mTextResource = text;
score_show = isShow;
link = href;
}

public boolean isImage_show() {
return image_show;
}

public void setImage_show(boolean image_show) {
this.image_show = image_show;
}

public String getLink() {
return link;
}

public void setLink(String link) {
this.link = link;
}

public boolean isScore_show() {
return score_show;
}

public void setScore_show(boolean score_show) {
this.score_show = score_show;
}

public String getText() {
return mTextResource;
}

public String getTitle() {
return mTitleResource;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

public class MapImg {
private int photoId;
private int userId;
private String userId;
private String diseaseName;
private Double lat;
private Double log;
Expand All @@ -16,11 +16,11 @@ public void setPhotoId(int photoId) {
this.photoId = photoId;
}

public int getUserId() {
public String getUserId() {
return userId;
}

public void setUserId(int userId) {
public void setUserId(String userId) {
this.userId = userId;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package com.example.ywang.diseaseidentification.utils;

import android.graphics.Bitmap;
import android.util.Log;

import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

import java.util.Arrays;

public class Classifier {

Module model;
float[] mean = {0.485f, 0.456f, 0.406f};
float[] std = {0.229f, 0.224f, 0.225f};

public Classifier(String modelPath) {
model = Module.load(modelPath);
}

public void setMeanAndStd(float[] mean, float[] std) {
this.mean = mean;
this.std = std;
}

public Tensor preprocess(Bitmap bitmap, int size) {
bitmap = Bitmap.createScaledBitmap(bitmap, size, size, false);
return TensorImageUtils.bitmapToFloat32Tensor(bitmap, this.mean, this.std);
}

public int argMax(float[] inputs) {
int maxIndex = -1;
float maxvalue = 0.0f;
for (int i = 0; i < inputs.length; i++) {
if (inputs[i] > maxvalue) {
maxIndex = i;
maxvalue = inputs[i];
}
}
return maxIndex;
}

public String predict(Bitmap bitmap) {
Tensor tensor = preprocess(bitmap, 325);
IValue inputs = IValue.from(tensor);
Tensor outputs = model.forward(inputs).toTensor();
float[] scores = outputs.getDataAsFloatArray();
Log.e("score", Arrays.toString(scores));

StringBuilder predict_result = new StringBuilder();
while (argMax(scores) != -1){
int classIndex = argMax(scores);
predict_result.append(Constants.DISEASE_CLASSES[classIndex]).append(";").append(scores[classIndex]).append(";");
scores[classIndex] = 0.0f;
}
return predict_result.toString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.example.ywang.diseaseidentification.utils;

public class Constants {
public static String[] DISEASE_CLASSES = new String[]{
"苹果健康", "苹果黑星病一般", "苹果黑星病严重", "苹果灰斑病", "苹果雪松锈病一般",
"苹果雪松锈病严重", "樱桃健康", "樱桃白粉病一般", "樱桃白粉病严重", "玉米健康",
"玉米灰斑病一般", "玉米灰斑病严重", "玉米锈病一般", "玉米锈病严重", "玉米叶斑病一般",
"玉米叶斑病严重", "玉米花叶病毒病", "葡萄健康", "葡萄黑腐病一般", "葡萄黑腐病严重",
"葡萄轮斑病一般", "葡萄轮斑病严重", "葡萄褐斑病一般", "葡萄褐斑病严重", "柑桔健康",
"柑桔黄龙病一般", "柑桔黄龙病严重", "桃健康", "桃疮痂病一般", "桃疮痂病严重",
"辣椒健康", "辣椒疮痂病一般", "辣椒疮痂病严重", "马铃薯健康", "马铃薯早疫病一般",
"马铃薯早疫病严重", "马铃薯晚疫病一般", "马铃薯晚疫病严重", "草莓健康", "草莓叶枯病一般",
"草莓叶枯病严重", "番茄健康", "番茄白粉病一般", "番茄白粉病严重", "番茄疮痂病一般",
"番茄疮痂病严重", "番茄早疫病一般", "番茄早疫病严重", "番茄晚疫病菌一般", "番茄晚疫病菌严重",
"番茄叶霉病一般", "番茄叶霉病严重", "番茄斑点病一般", "番茄斑点病严重", "番茄斑枯病一般",
"番茄斑枯病严重", "番茄红蜘蛛损伤一般", "番茄红蜘蛛损伤严重", "番茄黄化曲叶病毒病一般", "番茄黄化曲叶病毒病严重",
"番茄花叶病毒病"
};
}
Loading

0 comments on commit d4f6272

Please sign in to comment.