
Commit ae9032e
List of included changes imported by Copybara:
  - Android: add a kNN-based pose classification example to vision quickstart app
  - Android: add selfie segmentation to vision quickstart app
  - Android: update README for pose classifier and selfie segmentation

PiperOrigin-RevId: 359365704
Change-Id: I252773fa58f2b6dcbaf20dccc1f4af697a57135c
Authored by Google ML Kit and committed by Chengji Yan on Feb 26, 2021
1 parent 5a5e92e commit ae9032e
Showing 29 changed files with 2,312 additions and 116 deletions.
4 changes: 4 additions & 0 deletions android/vision-quickstart/README.md
@@ -14,6 +14,7 @@ Features that are included in this Quickstart app:
* [Image Labeling](https://developers.google.com/ml-kit/vision/image-labeling/android) - Label images in real time and static images
* [Custom Image Labeling - Birds](https://developers.google.com/ml-kit/vision/image-labeling/custom-models/android) - Label images of birds with a custom TensorFlow Lite model.
* [Pose Detection](https://developers.google.com/ml-kit/vision/pose-detection/android) - Detect the position of the human body in real time.
* [Selfie Segmentation](https://developers.google.com/ml-kit/vision/selfie-segmentation/android) - Segment people from the background in real time.

<img src="../screenshots/quickstart-picker.png" width="256"/> <img src="../screenshots/quickstart-image-labeling.png" width="256"/> <img src="../screenshots/quickstart-object-detection.png" width="256"/>

@@ -46,6 +47,9 @@ It uses the camera preview as input and contains these API workflows: Object det
* Show in-frame likelihood -- Displays InFrameLikelihood score for each landmark
* Visualize z value -- Uses different colors to indicate z difference (red: smaller z, blue: larger z)
* Rescale z value for visualization -- Maps the smallest z value to the most red and the largest z value to the most blue. This makes the z difference more obvious.
* Run classification -- Classify squat and pushup poses. Count reps in streaming mode.
* Selfie Segmentation
* Enable raw size mask -- Asks the segmenter to return the raw size mask which matches the model output size.

### Static Image scenario
The static image scenario is identical to the live camera scenario, but instead relies on images fed into the app through the gallery.
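For context, the Selfie Segmentation options described above map onto the public ML Kit API roughly as sketched below. The demo's actual SegmenterProcessor and its preference wiring are not rendered in this diff, so everything here other than the SelfieSegmenterOptions, Segmentation, and SegmentationMask calls is an illustrative assumption.

```java
// Sketch only: configuring the selfie segmenter the way the README options
// above describe. The class and method names here are illustrative.
import android.util.Log;
import com.google.mlkit.vision.common.InputImage;
import com.google.mlkit.vision.segmentation.Segmentation;
import com.google.mlkit.vision.segmentation.SegmentationMask;
import com.google.mlkit.vision.segmentation.Segmenter;
import com.google.mlkit.vision.segmentation.selfie.SelfieSegmenterOptions;
import java.nio.ByteBuffer;

public final class SelfieSegmentationSketch {

  public static void segment(InputImage image, boolean isStreamMode, boolean rawSizeMask) {
    SelfieSegmenterOptions.Builder builder =
        new SelfieSegmenterOptions.Builder()
            .setDetectorMode(
                isStreamMode
                    ? SelfieSegmenterOptions.STREAM_MODE
                    : SelfieSegmenterOptions.SINGLE_IMAGE_MODE);
    if (rawSizeMask) {
      // "Enable raw size mask": return the mask at the model's output size
      // instead of scaling it up to the input image size.
      builder.enableRawSizeMask();
    }
    Segmenter segmenter = Segmentation.getClient(builder.build());

    segmenter
        .process(image)
        .addOnSuccessListener(
            (SegmentationMask mask) -> {
              // Each float in the buffer is one pixel's foreground confidence,
              // row by row, over a width x height grid.
              ByteBuffer buffer = mask.getBuffer();
              int width = mask.getWidth();
              int height = mask.getHeight();
              float sum = 0f;
              for (int i = 0; i < width * height; i++) {
                sum += buffer.getFloat();
              }
              Log.d("SegmentationSketch",
                  "Average foreground confidence: " + sum / (width * height));
            });
  }
}
```

Stream mode corresponds to the live camera scenario and single image mode to the static image scenario.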
7 changes: 5 additions & 2 deletions android/vision-quickstart/app/build.gradle
@@ -87,9 +87,12 @@ dependencies {
implementation 'com.google.mlkit:image-labeling-custom:16.3.1'

// Pose detection with default models
implementation 'com.google.mlkit:pose-detection:17.0.1-beta2'
implementation 'com.google.mlkit:pose-detection:17.0.1-beta3'
// Pose detection with accurate models
implementation 'com.google.mlkit:pose-detection-accurate:17.0.1-beta2'
implementation 'com.google.mlkit:pose-detection-accurate:17.0.1-beta3'

// Selfie segmentation
implementation 'com.google.mlkit:segmentation-selfie:16.0.0-beta1'

// -------------------------------------------------------

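Both pose artifacts move to 17.0.1-beta3 together. For reference, the base artifact backs PoseDetectorOptions and the accurate artifact backs AccuratePoseDetectorOptions, and either one feeds the same PoseDetection client, roughly as in the sketch below. The demo presumably builds these options in PreferenceUtils, which is not rendered in this diff, so the helper class here is illustrative.

```java
// Sketch only: mapping the two Gradle artifacts above to their options classes.
import com.google.mlkit.vision.pose.PoseDetection;
import com.google.mlkit.vision.pose.PoseDetector;
import com.google.mlkit.vision.pose.PoseDetectorOptionsBase;
import com.google.mlkit.vision.pose.accurate.AccuratePoseDetectorOptions;
import com.google.mlkit.vision.pose.defaults.PoseDetectorOptions;

public final class PoseDetectorChoiceSketch {

  public static PoseDetector create(boolean useAccurateModel, boolean isStreamMode) {
    PoseDetectorOptionsBase options;
    if (useAccurateModel) {
      // Backed by the pose-detection-accurate artifact.
      options =
          new AccuratePoseDetectorOptions.Builder()
              .setDetectorMode(
                  isStreamMode
                      ? AccuratePoseDetectorOptions.STREAM_MODE
                      : AccuratePoseDetectorOptions.SINGLE_IMAGE_MODE)
              .build();
    } else {
      // Backed by the base pose-detection artifact.
      options =
          new PoseDetectorOptions.Builder()
              .setDetectorMode(
                  isStreamMode
                      ? PoseDetectorOptions.STREAM_MODE
                      : PoseDetectorOptions.SINGLE_IMAGE_MODE)
              .build();
    }
    return PoseDetection.getClient(options);
  }
}
```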

Large diffs are not rendered by default.

CameraXLivePreviewActivity.java
@@ -61,6 +61,7 @@
import com.google.mlkit.vision.demo.java.labeldetector.LabelDetectorProcessor;
import com.google.mlkit.vision.demo.java.objectdetector.ObjectDetectorProcessor;
import com.google.mlkit.vision.demo.java.posedetector.PoseDetectorProcessor;
import com.google.mlkit.vision.demo.java.segmenter.SegmenterProcessor;
import com.google.mlkit.vision.demo.java.textdetector.TextRecognitionProcessor;
import com.google.mlkit.vision.demo.preference.PreferenceUtils;
import com.google.mlkit.vision.demo.preference.SettingsActivity;
@@ -94,6 +95,7 @@ public final class CameraXLivePreviewActivity extends AppCompatActivity
private static final String IMAGE_LABELING_CUSTOM = "Custom Image Labeling (Bird)";
private static final String CUSTOM_AUTOML_LABELING = "Custom AutoML Image Labeling (Flower)";
private static final String POSE_DETECTION = "Pose Detection";
private static final String SELFIE_SEGMENTATION = "Selfie Segmentation";

private static final String STATE_SELECTED_MODEL = "selected_model";
private static final String STATE_LENS_FACING = "lens_facing";
@@ -154,6 +156,7 @@ protected void onCreate(Bundle savedInstanceState) {
options.add(IMAGE_LABELING_CUSTOM);
options.add(CUSTOM_AUTOML_LABELING);
options.add(POSE_DETECTION);
options.add(SELFIE_SEGMENTATION);

// Creating adapter for spinner
ArrayAdapter<String> dataAdapter = new ArrayAdapter<>(this, R.layout.spinner_style, options);
@@ -379,9 +382,14 @@ private void bindAnalysisUseCase() {
PreferenceUtils.shouldShowPoseDetectionInFrameLikelihoodLivePreview(this);
boolean visualizeZ = PreferenceUtils.shouldPoseDetectionVisualizeZ(this);
boolean rescaleZ = PreferenceUtils.shouldPoseDetectionRescaleZForVisualization(this);
boolean runClassification = PreferenceUtils.shouldPoseDetectionRunClassification(this);
imageProcessor =
new PoseDetectorProcessor(
this, poseDetectorOptions, shouldShowInFrameLikelihood, visualizeZ, rescaleZ);
this, poseDetectorOptions, shouldShowInFrameLikelihood, visualizeZ, rescaleZ,
runClassification, /* isStreamMode = */true);
break;
case SELFIE_SEGMENTATION:
imageProcessor = new SegmenterProcessor(this);
break;
default:
throw new IllegalStateException("Invalid model name");
LivePreviewActivity.java
@@ -46,6 +46,7 @@
import com.google.mlkit.vision.demo.java.labeldetector.LabelDetectorProcessor;
import com.google.mlkit.vision.demo.java.objectdetector.ObjectDetectorProcessor;
import com.google.mlkit.vision.demo.java.posedetector.PoseDetectorProcessor;
import com.google.mlkit.vision.demo.java.segmenter.SegmenterProcessor;
import com.google.mlkit.vision.demo.java.textdetector.TextRecognitionProcessor;
import com.google.mlkit.vision.demo.preference.PreferenceUtils;
import com.google.mlkit.vision.demo.preference.SettingsActivity;
@@ -76,6 +77,7 @@ public final class LivePreviewActivity extends AppCompatActivity
private static final String IMAGE_LABELING_CUSTOM = "Custom Image Labeling (Bird)";
private static final String CUSTOM_AUTOML_LABELING = "Custom AutoML Image Labeling (Flower)";
private static final String POSE_DETECTION = "Pose Detection";
private static final String SELFIE_SEGMENTATION = "Selfie Segmentation";

private static final String TAG = "LivePreviewActivity";
private static final int PERMISSION_REQUESTS = 1;
@@ -113,6 +115,7 @@ protected void onCreate(Bundle savedInstanceState) {
options.add(IMAGE_LABELING_CUSTOM);
options.add(CUSTOM_AUTOML_LABELING);
options.add(POSE_DETECTION);
options.add(SELFIE_SEGMENTATION);

// Creating adapter for spinner
ArrayAdapter<String> dataAdapter = new ArrayAdapter<>(this, R.layout.spinner_style, options);
@@ -261,8 +264,13 @@ private void createCameraSource(String model) {
PreferenceUtils.shouldShowPoseDetectionInFrameLikelihoodLivePreview(this);
boolean visualizeZ = PreferenceUtils.shouldPoseDetectionVisualizeZ(this);
boolean rescaleZ = PreferenceUtils.shouldPoseDetectionRescaleZForVisualization(this);
boolean runClassification = PreferenceUtils.shouldPoseDetectionRunClassification(this);
cameraSource.setMachineLearningFrameProcessor(new PoseDetectorProcessor(
this, poseDetectorOptions, shouldShowInFrameLikelihood, visualizeZ, rescaleZ));
this, poseDetectorOptions, shouldShowInFrameLikelihood, visualizeZ, rescaleZ,
runClassification, /* isStreamMode = */true));
break;
case SELFIE_SEGMENTATION:
cameraSource.setMachineLearningFrameProcessor(new SegmenterProcessor(this));
break;
default:
Log.e(TAG, "Unknown model: " + model);
StillImageActivity.java
@@ -47,6 +47,7 @@
import com.google.mlkit.vision.demo.java.labeldetector.LabelDetectorProcessor;
import com.google.mlkit.vision.demo.java.objectdetector.ObjectDetectorProcessor;
import com.google.mlkit.vision.demo.java.posedetector.PoseDetectorProcessor;
import com.google.mlkit.vision.demo.java.segmenter.SegmenterProcessor;
import com.google.mlkit.vision.demo.java.textdetector.TextRecognitionProcessor;
import com.google.mlkit.vision.demo.preference.PreferenceUtils;
import com.google.mlkit.vision.demo.preference.SettingsActivity;
@@ -76,6 +77,7 @@ public final class StillImageActivity extends AppCompatActivity {
private static final String IMAGE_LABELING_CUSTOM = "Custom Image Labeling (Bird)";
private static final String CUSTOM_AUTOML_LABELING = "Custom AutoML Image Labeling (Flower)";
private static final String POSE_DETECTION = "Pose Detection";
private static final String SELFIE_SEGMENTATION = "Selfie Segmentation";

private static final String SIZE_SCREEN = "w:screen"; // Match screen width
private static final String SIZE_1024_768 = "w:1024"; // ~1024*768 in a normal ratio
@@ -187,6 +189,7 @@ private void populateFeatureSelector() {
options.add(IMAGE_LABELING_CUSTOM);
options.add(CUSTOM_AUTOML_LABELING);
options.add(POSE_DETECTION);
options.add(SELFIE_SEGMENTATION);

// Creating adapter for featureSpinner
ArrayAdapter<String> dataAdapter = new ArrayAdapter<>(this, R.layout.spinner_style, options);
@@ -425,9 +428,14 @@ private void createImageProcessor() {
PreferenceUtils.shouldShowPoseDetectionInFrameLikelihoodStillImage(this);
boolean visualizeZ = PreferenceUtils.shouldPoseDetectionVisualizeZ(this);
boolean rescaleZ = PreferenceUtils.shouldPoseDetectionRescaleZForVisualization(this);
boolean runClassification = PreferenceUtils.shouldPoseDetectionRunClassification(this);
imageProcessor =
new PoseDetectorProcessor(
this, poseDetectorOptions, shouldShowInFrameLikelihood, visualizeZ, rescaleZ);
this, poseDetectorOptions, shouldShowInFrameLikelihood, visualizeZ, rescaleZ,
runClassification, /* isStreamMode = */false);
break;
case SELFIE_SEGMENTATION:
imageProcessor = new SegmenterProcessor(this, /* isStreamMode= */ false);
break;
default:
Log.e(TAG, "Unknown selectedMode: " + selectedMode);
PoseDetectorProcessor.java
@@ -23,33 +23,70 @@
import com.google.mlkit.vision.common.InputImage;
import com.google.mlkit.vision.demo.GraphicOverlay;
import com.google.mlkit.vision.demo.java.VisionProcessorBase;
import com.google.mlkit.vision.demo.java.posedetector.classification.PoseClassifierProcessor;
import com.google.mlkit.vision.pose.Pose;
import com.google.mlkit.vision.pose.PoseDetection;
import com.google.mlkit.vision.pose.PoseDetector;
import com.google.mlkit.vision.pose.PoseDetectorOptionsBase;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;

/** A processor to run pose detector. */
public class PoseDetectorProcessor extends VisionProcessorBase<Pose> {

public class PoseDetectorProcessor
extends VisionProcessorBase<PoseDetectorProcessor.PoseWithClassification> {
private static final String TAG = "PoseDetectorProcessor";

private final PoseDetector detector;

private final boolean showInFrameLikelihood;
private final boolean visualizeZ;
private final boolean rescaleZForVisualization;
private final boolean runClassification;
private final boolean isStreamMode;
private final Context context;
private final Executor classificationExecutor;

private PoseClassifierProcessor poseClassifierProcessor;
/**
* Internal class to hold Pose and classification results.
*/
protected static class PoseWithClassification {
private final Pose pose;
private final List<String> classificationResult;

public PoseWithClassification(Pose pose, List<String> classificationResult) {
this.pose = pose;
this.classificationResult = classificationResult;
}

public Pose getPose() {
return pose;
}

public List<String> getClassificationResult() {
return classificationResult;
}
}

public PoseDetectorProcessor(
Context context,
PoseDetectorOptionsBase options,
boolean showInFrameLikelihood,
boolean visualizeZ,
boolean rescaleZForVisualization) {
boolean rescaleZForVisualization,
boolean runClassification,
boolean isStreamMode) {
super(context);
this.showInFrameLikelihood = showInFrameLikelihood;
this.visualizeZ = visualizeZ;
this.rescaleZForVisualization = rescaleZForVisualization;
detector = PoseDetection.getClient(options);
this.runClassification = runClassification;
this.isStreamMode = isStreamMode;
this.context = context;
classificationExecutor = Executors.newSingleThreadExecutor();
}

@Override
@@ -59,15 +96,32 @@ public void stop() {
}

@Override
protected Task<Pose> detectInImage(InputImage image) {
return detector.process(image);
protected Task<PoseWithClassification> detectInImage(InputImage image) {
return detector
.process(image)
.continueWith(
classificationExecutor,
task -> {
Pose pose = task.getResult();
List<String> classificationResult = new ArrayList<>();
if (runClassification) {
if (poseClassifierProcessor == null) {
poseClassifierProcessor = new PoseClassifierProcessor(context, isStreamMode);
}
classificationResult = poseClassifierProcessor.getPoseResult(pose);
}
return new PoseWithClassification(pose, classificationResult);
});
}

@Override
protected void onSuccess(@NonNull Pose pose, @NonNull GraphicOverlay graphicOverlay) {
protected void onSuccess(
@NonNull PoseWithClassification poseWithClassification,
@NonNull GraphicOverlay graphicOverlay) {
graphicOverlay.add(
new PoseGraphic(
graphicOverlay, pose, showInFrameLikelihood, visualizeZ, rescaleZForVisualization));
graphicOverlay, poseWithClassification.pose, showInFrameLikelihood, visualizeZ,
rescaleZForVisualization, poseWithClassification.classificationResult));
}

@Override
PoseGraphic.java
@@ -38,6 +38,7 @@ public class PoseGraphic extends Graphic {
private static final float DOT_RADIUS = 8.0f;
private static final float IN_FRAME_LIKELIHOOD_TEXT_SIZE = 30.0f;
private static final float STROKE_WIDTH = 10.0f;
private static final float POSE_CLASSIFICATION_TEXT_SIZE = 60.0f;

private final Pose pose;
private final boolean showInFrameLikelihood;
@@ -46,6 +47,8 @@ public class PoseGraphic extends Graphic {
private float zMin = Float.MAX_VALUE;
private float zMax = Float.MIN_VALUE;

private final List<String> poseClassification;
private final Paint classificationTextPaint;
private final Paint leftPaint;
private final Paint rightPaint;
private final Paint whitePaint;
@@ -55,13 +58,19 @@
Pose pose,
boolean showInFrameLikelihood,
boolean visualizeZ,
boolean rescaleZForVisualization) {
boolean rescaleZForVisualization,
List<String> poseClassification) {
super(overlay);
this.pose = pose;
this.showInFrameLikelihood = showInFrameLikelihood;
this.visualizeZ = visualizeZ;
this.rescaleZForVisualization = rescaleZForVisualization;

this.poseClassification = poseClassification;
classificationTextPaint = new Paint();
classificationTextPaint.setColor(Color.WHITE);
classificationTextPaint.setTextSize(POSE_CLASSIFICATION_TEXT_SIZE);

whitePaint = new Paint();
whitePaint.setStrokeWidth(STROKE_WIDTH);
whitePaint.setColor(Color.WHITE);
@@ -81,6 +90,18 @@ public void draw(Canvas canvas) {
return;
}

// Draw pose classification text.
float classificationX = POSE_CLASSIFICATION_TEXT_SIZE * 0.5f;
for (int i = 0; i < poseClassification.size(); i++) {
float classificationY = (canvas.getHeight() - POSE_CLASSIFICATION_TEXT_SIZE * 1.5f
* (poseClassification.size() - i));
canvas.drawText(
poseClassification.get(i),
classificationX,
classificationY,
classificationTextPaint);
}

// Draw all the points
for (PoseLandmark landmark : landmarks) {
drawPoint(canvas, landmark, whitePaint);
ClassificationResult.java (new file)
@@ -0,0 +1,61 @@
/*
* Copyright 2020 Google LLC. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.mlkit.vision.demo.java.posedetector.classification;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import static java.util.Collections.max;

/**
* Represents the pose classification result output by {@link PoseClassifier}. Can be manipulated.
*/
public class ClassificationResult {
// For an entry in this map, the key is the class name, and the value is how many times this class
// appears in the top K nearest neighbors. The value is in range [0, K] and could be a float after
// EMA smoothing. We use this number to represent the confidence of a pose being in this class.
private final Map<String, Float> classConfidences;

public ClassificationResult() {
classConfidences = new HashMap<>();
}

public Set<String> getAllClasses() {
return classConfidences.keySet();
}

public float getClassConfidence(String className) {
return classConfidences.containsKey(className) ? classConfidences.get(className) : 0;
}

public String getMaxConfidenceClass() {
return max(
classConfidences.entrySet(),
(entry1, entry2) -> Float.compare(entry1.getValue(), entry2.getValue()))
.getKey();
}

public void incrementClassConfidence(String className) {
classConfidences.put(className,
classConfidences.containsKey(className) ? classConfidences.get(className) + 1 : 1);
}

public void putClassConfidence(String className, float confidence) {
classConfidences.put(className, confidence);
}
}
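To make the intended use of this class concrete, the sketch below shows how a kNN vote could populate a ClassificationResult: each of the K nearest labeled samples adds one vote through incrementClassConfidence, and getMaxConfidenceClass then names the winning pose class. The PoseSample type, the embedding, and the distance function are illustrative assumptions; the demo's PoseClassifier and PoseClassifierProcessor are not rendered in this diff.

```java
// Sketch only: a kNN vote that fills a ClassificationResult.
package com.google.mlkit.vision.demo.java.posedetector.classification;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

final class KnnVoteSketch {

  /** Illustrative labeled sample: a pose class name plus a numeric pose embedding. */
  static final class PoseSample {
    final String className;
    final float[] embedding;

    PoseSample(String className, float[] embedding) {
      this.className = className;
      this.embedding = embedding;
    }
  }

  static ClassificationResult classify(float[] input, List<PoseSample> samples, int k) {
    // Rank all labeled samples by distance to the input pose embedding.
    List<PoseSample> ranked = new ArrayList<>(samples);
    ranked.sort(Comparator.comparingDouble((PoseSample s) -> l2Distance(s.embedding, input)));

    // Each of the top K neighbors contributes one vote to its class; this is
    // exactly the [0, K] count that ClassificationResult stores per class.
    ClassificationResult result = new ClassificationResult();
    for (int i = 0; i < Math.min(k, ranked.size()); i++) {
      result.incrementClassConfidence(ranked.get(i).className);
    }
    return result;
  }

  private static double l2Distance(float[] a, float[] b) {
    double sum = 0;
    for (int i = 0; i < a.length; i++) {
      double d = a[i] - b[i];
      sum += d * d;
    }
    return Math.sqrt(sum);
  }
}
```

A streaming wrapper such as PoseClassifierProcessor can then smooth these per-frame counts (for example with EMA, storing floats via putClassConfidence) and derive rep counts for the squat and pushup classes mentioned in the README.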