Skip to content

Commit

Permalink
working on improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jbilcke-hf committed Oct 8, 2024
1 parent 6525751 commit e123fec
Show file tree
Hide file tree
Showing 17 changed files with 1,372 additions and 1,218 deletions.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
title: FacePoke
emoji: 💬
emoji: 🙂‍↔️👈
colorFrom: yellow
colorTo: red
sdk: docker
Expand Down Expand Up @@ -115,6 +115,14 @@ The project structure is organized as follows:
- `src/`: TypeScript source files.
- `public/`: Static assets and built files.

### Increasing the framerate

I am testing various things to increase the framerate.

One approach is to transmit only the modified head region, instead of the whole image.

Another is to adapt automatically to the server and network speed.

## Contributing

Contributions to FacePoke are welcome! Please read our [Contributing Guidelines](CONTRIBUTING.md) for details on how to submit pull requests, report issues, or request features.
Expand Down
102 changes: 34 additions & 68 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import asyncio
from aiohttp import web, WSMsgType
import json
from json import JSONEncoder
import numpy as np
import uuid
import logging
import os
Expand All @@ -18,16 +20,18 @@
import io

from PIL import Image

# by popular demand, let's add support for avif
import pillow_avif

# Configure logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Set asyncio logger to DEBUG level
logging.getLogger("asyncio").setLevel(logging.DEBUG)
#logging.getLogger("asyncio").setLevel(logging.INFO)

logger.debug(f"Python version: {sys.version}")
#logger.debug(f"Python version: {sys.version}")

# SIGSEGV handler
def SIGSEGV_signal_arises(signalNum, stack):
Expand All @@ -43,89 +47,51 @@ def SIGSEGV_signal_arises(signalNum, stack):
DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
MODELS_DIR = os.path.join(DATA_ROOT, "models")

image_cache: Dict[str, Image.Image] = {}
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that handles NumPy scalar and array types.

    Converts ``np.integer`` -> ``int``, ``np.floating`` -> ``float`` and
    ``np.ndarray`` -> ``list`` so that engine results containing NumPy
    values can be serialized and sent to the client as plain JSON.
    """

    def default(self, obj):
        """Return a JSON-serializable form of *obj*, deferring to the base class otherwise."""
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Base implementation raises TypeError for unsupported types.
        return super().default(obj)

async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
    """
    Handle a WebSocket connection for the FacePoke application.

    Protocol:
      - BINARY frames carry a raw image; the engine loads it and the
        (NumPy-bearing) result is sent back as JSON text.
      - TEXT frames carry a JSON payload with an image ``hash`` and transform
        ``params``; the transformed image is sent back as WebP bytes.
      - CLOSE/ERROR frames end the receive loop.

    Args:
        request (web.Request): The incoming request object.

    Returns:
        web.WebSocketResponse: The WebSocket response object.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)
    engine = request.app['engine']
    try:
        while True:
            msg = await ws.receive()

            try:
                if msg.type == WSMsgType.BINARY:
                    # A new source image: load it and return engine metadata as JSON.
                    res = await engine.load_image(msg.data)
                    json_res = json.dumps(res, cls=NumpyEncoder)
                    await ws.send_str(json_res)

                elif msg.type == WSMsgType.TEXT:
                    # A transform request referencing a previously loaded image.
                    # NOTE(review): user payloads are intentionally not logged — they are heavy.
                    data = json.loads(msg.data)
                    webp_bytes = engine.transform_image(data.get('hash'), data.get('params'))
                    await ws.send_bytes(webp_bytes)

                elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
                    #logger.warning(f"WebSocket connection closed: {msg.type}")
                    break
            except Exception as e:
                # Report a per-message engine failure to the client without
                # tearing down the whole connection.
                logger.error(f"Error in engine: {str(e)}")
                logger.exception("Full traceback:")
                await ws.send_json({"error": str(e)})

    except Exception as e:
        logger.error(f"Error in websocket_handler: {str(e)}")
        logger.exception("Full traceback:")
    return ws

async def handle_modify_image(request: web.Request, ws: web.WebSocketResponse, msg: Dict[str, Any], uuid: str):
    """
    Handle the 'modify_image' request.

    Resolves the target image (by cached hash when present, otherwise by the
    inline image payload), asks the engine to apply the requested
    modification, and sends the result — or an error report — back over the
    WebSocket, tagged with the request UUID either way.

    Args:
        request (web.Request): The incoming request object.
        ws (web.WebSocketResponse): The WebSocket response object.
        msg (Dict[str, Any]): The message containing the image or image_hash and modification parameters.
        uuid: A unique identifier for the request.
    """
    try:
        engine = request.app['engine']
        # Prefer the cached hash; fall back to the raw image payload.
        image_or_hash = msg.get('image_hash') or msg['image']

        modified_image_base64 = await engine.modify_image(image_or_hash, msg['params'])

        response = {
            "type": "modified_image",
            "image": modified_image_base64,
            "image_hash": engine.get_image_hash(image_or_hash),
            "success": True,
            "uuid": uuid,  # Include the UUID in the response
        }
        await ws.send_json(response)
    except Exception as e:
        # Report the failure to the client, still carrying the request UUID.
        await ws.send_json({
            "type": "modified_image",
            "success": False,
            "error": str(e),
            "uuid": uuid,  # Include the UUID even in error responses
        })

async def index(request: web.Request) -> web.Response:
"""Serve the index.html file"""
content = open(os.path.join(os.path.dirname(__file__), "public", "index.html"), "r").read()
Expand Down
116 changes: 52 additions & 64 deletions client/src/app.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,88 +4,52 @@ import { Download } from 'lucide-react';
import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert';
import { truncateFileName } from './lib/utils';
import { useFaceLandmarkDetection } from './hooks/useFaceLandmarkDetection';
import { PoweredBy } from './components/PoweredBy';
import { About } from './components/About';
import { Spinner } from './components/Spinner';
import { useFacePokeAPI } from './hooks/useFacePokeAPI';
import { Layout } from './layout';
import { useMainStore } from './hooks/useMainStore';
import { convertImageToBase64 } from './lib/convertImageToBase64';

export function App() {
const error = useMainStore(s => s.error);
const setError = useMainStore(s => s.setError);
const imageFile = useMainStore(s => s.imageFile);
const setImageFile = useMainStore(s => s.setImageFile);
const originalImage = useMainStore(s => s.originalImage);
const setOriginalImage = useMainStore(s => s.setOriginalImage);
const isGazingAtCursor = useMainStore(s => s.isGazingAtCursor);
const setIsGazingAtCursor = useMainStore(s => s.setIsGazingAtCursor);
const isFollowingCursor = useMainStore(s => s.isFollowingCursor);
const setIsFollowingCursor = useMainStore(s => s.setIsFollowingCursor);

const previewImage = useMainStore(s => s.previewImage);
const setPreviewImage = useMainStore(s => s.setPreviewImage);
const resetImage = useMainStore(s => s.resetImage);
const setOriginalImageHash = useMainStore(s => s.setOriginalImageHash);
const status = useMainStore(s => s.status);
const blendShapes = useMainStore(s => s.blendShapes);

const {
status,
setStatus,
isDebugMode,
setIsDebugMode,
interruptMessage,
} = useFacePokeAPI()

// State for face detection
const {
canvasRef,
canvasRefCallback,
mediaPipeRef,
faceLandmarks,
isMediaPipeReady,
blendShapes,

setFaceLandmarks,
setBlendShapes,

handleMouseDown,
handleMouseUp,
handleMouseMove,
handleMouseEnter,
handleMouseLeave,
handleTouchStart,
handleTouchMove,
handleTouchEnd,
currentOpacity
} = useFaceLandmarkDetection()

// Refs
const videoRef = useRef<HTMLDivElement>(null);

// Handle file change
const handleFileChange = useCallback(async (event: React.ChangeEvent<HTMLInputElement>) => {
const handleFileChange = useCallback((event: React.ChangeEvent<HTMLInputElement>) => {
const files = event.target.files;
if (files && files[0]) {
setImageFile(files[0]);
setStatus(`File selected: ${truncateFileName(files[0].name, 16)}`);

try {
const image = await convertImageToBase64(files[0]);
setPreviewImage(image);
setOriginalImage(image);
setOriginalImageHash('');
} catch (err) {
console.log(`failed to convert the image: `, err);
setImageFile(null);
setStatus('');
setPreviewImage('');
setOriginalImage('');
setOriginalImageHash('');
setFaceLandmarks([]);
setBlendShapes([]);
}
} else {
setImageFile(null);
setStatus('');
setPreviewImage('');
setOriginalImage('');
setOriginalImageHash('');
setFaceLandmarks([]);
setBlendShapes([]);
}
}, [isMediaPipeReady, setImageFile, setPreviewImage, setOriginalImage, setOriginalImageHash, setFaceLandmarks, setBlendShapes, setStatus]);
setImageFile(files?.[0] || undefined)
}, [setImageFile]);

const handleDownload = useCallback(() => {
if (previewImage) {
Expand Down Expand Up @@ -139,7 +103,7 @@ export function App() {
<div className="mb-5 relative">
<div className="flex flex-row items-center justify-between w-full">
<div className="flex items-center space-x-2">
<div className="relative">
<div className="flex items-center justify-center">
<input
id="imageInput"
type="file"
Expand All @@ -155,7 +119,7 @@ export function App() {
} focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-slate-500 shadow-xl`}
>
<Spinner />
{imageFile ? truncateFileName(imageFile.name, 32) : (isMediaPipeReady ? 'Choose a portrait photo (.jpg, .png, .webp)' : 'Initializing...')}
{imageFile ? truncateFileName(imageFile.name, 32) : (isMediaPipeReady ? 'Choose a portrait photo' : 'Initializing...')}
</label>
</div>
{previewImage && (
Expand All @@ -168,15 +132,38 @@ export function App() {
</button>
)}
</div>
{previewImage && <label className="mt-4 flex items-center">
<input
type="checkbox"
checked={isDebugMode}
onChange={(e) => setIsDebugMode(e.target.checked)}
className="mr-2"
/>
Show face landmarks on hover
</label>}
{previewImage && <div className="flex items-center space-x-2">
{/* experimental features, not active yet */}
{/*
<label className="mt-4 flex items-center">
<input
type="checkbox"
checked={isGazingAtCursor}
onChange={(e) => setIsGazingAtCursor(!isGazingAtCursor)}
className="mr-2"
/>
Autotrack eyes
</label>
<label className="mt-4 flex items-center">
<input
type="checkbox"
checked={isFollowingCursor}
onChange={(e) => setIsFollowingCursor(!isFollowingCursor)}
className="mr-2"
/>
Autotrack head
</label>
*/}
<label className="mt-4 flex items-center">
<input
type="checkbox"
checked={isDebugMode}
onChange={(e) => setIsDebugMode(e.target.checked)}
className="mr-2"
/>
Show face markers
</label>
</div>}
</div>
{previewImage && (
<div className="mt-5 relative shadow-2xl rounded-xl overflow-hidden">
Expand All @@ -188,11 +175,12 @@ export function App() {
<canvas
ref={canvasRefCallback}
className="absolute top-0 left-0 w-full h-full select-none"
onMouseEnter={handleMouseEnter}
onMouseLeave={handleMouseLeave}
onMouseDown={handleMouseDown}
onMouseUp={handleMouseUp}
onMouseMove={handleMouseMove}
onTouchStart={handleTouchStart}
onTouchMove={handleTouchMove}
onTouchEnd={handleTouchEnd}
style={{
position: 'absolute',
top: 0,
Expand All @@ -207,7 +195,7 @@ export function App() {
)}
{canDisplayBlendShapes && displayBlendShapes}
</div>
<PoweredBy />
<About />
</Layout>
);
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
export function PoweredBy() {
export function About() {
return (
<div className="flex flex-row items-center justify-center font-sans mt-4 w-full">
{/*<span className="text-neutral-900 text-sm"
<span className="text-neutral-900 text-sm"
style={{ textShadow: "rgb(255 255 255 / 80%) 0px 0px 2px" }}>
Best hosted on
</span>*/}
<span className="mr-1">
Click and drag on the image.
</span>
<span className="ml-2 mr-1">
<img src="/hf-logo.svg" alt="Hugging Face" className="w-5 h-5" />
</span>
<span className="text-neutral-900 text-sm font-semibold"
Expand Down
Loading

0 comments on commit e123fec

Please sign in to comment.