Skip to content

Commit

Permalink
Merge pull request tensorflow#359 from khanhlvg:master
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 409493056
  • Loading branch information
copybara-github committed Nov 12, 2021
2 parents c930d0a + 7a94266 commit 1f7a780
Show file tree
Hide file tree
Showing 10 changed files with 680 additions and 50 deletions.
10 changes: 10 additions & 0 deletions lite/examples/pose_estimation/android/app/download.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ task downloadMovenetThunderModel(type: DownloadUrlTask) {
target = file("src/main/assets/movenet_thunder.tflite")
}

task downloadMovenetMultiPoseModel(type: DownloadUrlTask) {
def modelMovenetThunderDownloadUrl = "https://tfhub.dev/google/lite-model/movenet/multipose/lightning/tflite/float16/1?lite-format=tflite"
doFirst {
println "Downloading ${modelMovenetThunderDownloadUrl}"
}
sourceUrl = "${modelMovenetThunderDownloadUrl}"
target = file("src/main/assets/movenet_multipose_fp16.tflite")
}

task downloadPoseClassifierModel(type: DownloadUrlTask) {
def modelPoseClassifierDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/pose_classifier/yoga_classifier.tflite"
doFirst {
Expand All @@ -39,6 +48,7 @@ task downloadModel {
dependsOn downloadMovenetLightningModel
dependsOn downloadMovenetThunderModel
dependsOn downloadPoseClassifierModel
dependsOn downloadMovenetMultiPoseModel
}

class DownloadUrlTask extends DefaultTask {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.tensorflow.lite.examples.poseestimation.ml

import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Canvas
import android.graphics.PointF
import androidx.test.platform.app.InstrumentationRegistry
import com.google.common.truth.Truth.assertThat
Expand Down Expand Up @@ -88,8 +89,25 @@ object EvaluationUtils {
return data
}


/**
* Calculate the distance between two points
*/
private fun distance(point1: PointF, point2: PointF): Float {
return ((point1.x - point2.x).pow(2) + (point1.y - point2.y).pow(2)).pow(0.5f)
}

/**
* Concatenate images of same height horizontally
*/
fun hConcat(image1: Bitmap, image2: Bitmap): Bitmap {
if (image1.height != image2.height) {
throw Exception("Input images are not same height.")
}
val finalBitmap =
Bitmap.createBitmap(image1.width + image2.width, image1.height, Bitmap.Config.ARGB_8888)
val canvas = Canvas(finalBitmap)
canvas.drawBitmap(image1, 0f, 0f, null)
canvas.drawBitmap(image2, image1.width.toFloat(), 0f, null)
return finalBitmap
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package org.tensorflow.lite.examples.poseestimation.ml

import android.content.Context
import android.graphics.Bitmap
import android.graphics.PointF
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.tensorflow.lite.examples.poseestimation.data.BodyPart
import org.tensorflow.lite.examples.poseestimation.data.Device
import org.tensorflow.lite.examples.poseestimation.ml.MoveNetMultiPose
import org.tensorflow.lite.examples.poseestimation.ml.Type

@RunWith(AndroidJUnit4::class)
class MovenetMultiPoseTest {
companion object {
private const val TEST_INPUT_IMAGE1 = "image1.png"
private const val TEST_INPUT_IMAGE2 = "image2.jpg"
private const val ACCEPTABLE_ERROR = 17f
}

private lateinit var poseDetector: MoveNetMultiPose
private lateinit var appContext: Context
private lateinit var inputFinal: Bitmap
private lateinit var expectedDetectionResult: List<Map<BodyPart, PointF>>

@Before
fun setup() {
appContext = InstrumentationRegistry.getInstrumentation().targetContext
poseDetector = MoveNetMultiPose.create(appContext, Device.CPU, Type.Dynamic)
val input1 = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE1)
val input2 = EvaluationUtils.loadBitmapAssetByName(TEST_INPUT_IMAGE2)
inputFinal = EvaluationUtils.hConcat(input1, input2)
expectedDetectionResult =
EvaluationUtils.loadCSVAsset("pose_landmark_truth.csv")

// update coordination of the pose_landmark_truth.csv corresponding to the new input image
for ((_, value) in expectedDetectionResult[1]) {
value.x = value.x + input1.width
}
}

@Test
fun testPoseEstimateResult() {
val persons = poseDetector.estimatePoses(inputFinal)
assert(persons.size == 2)

// Sort the results so that the person on the right side come first.
val sortedPersons = persons.sortedBy { it.boundingBox?.left }

EvaluationUtils.assertPoseDetectionResult(
sortedPersons[0],
expectedDetectionResult[0],
ACCEPTABLE_ERROR
)

EvaluationUtils.assertPoseDetectionResult(
sortedPersons[1],
expectedDetectionResult[1],
ACCEPTABLE_ERROR
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class VisualizationTest {
fun testPosenet() {
val poseDetector = PoseNet.create(appContext, Device.CPU)
val person = poseDetector.estimatePoses(inputBitmap)[0]
val outputBitmap = VisualizationUtils.drawBodyKeypoints(inputBitmap, person)
val outputBitmap = VisualizationUtils.drawBodyKeypoints(inputBitmap, arrayListOf(person))
assertThat(outputBitmap).isNotNull()
}

Expand All @@ -64,7 +64,7 @@ class VisualizationTest {
poseDetector.estimatePoses(inputBitmap)
poseDetector.estimatePoses(inputBitmap)
val person2 = poseDetector.estimatePoses(inputBitmap)[0]
val outputBitmap2 = VisualizationUtils.drawBodyKeypoints(inputBitmap, person2)
val outputBitmap2 = VisualizationUtils.drawBodyKeypoints(inputBitmap, arrayListOf(person2))
assertThat(outputBitmap2).isNotNull()
}

Expand All @@ -76,7 +76,7 @@ class VisualizationTest {
poseDetector.estimatePoses(inputBitmap)
poseDetector.estimatePoses(inputBitmap)
val person = poseDetector.estimatePoses(inputBitmap)[0]
val outputBitmap = VisualizationUtils.drawBodyKeypoints(inputBitmap, person)
val outputBitmap = VisualizationUtils.drawBodyKeypoints(inputBitmap, arrayListOf(person))
assertThat(outputBitmap).isNotNull()
}
}
Loading

0 comments on commit 1f7a780

Please sign in to comment.