Skip to content

Commit

Permalink
[rpc] use callback func to do send & recv (apache#4147)
Browse files Browse the repository at this point in the history
* [rpc] use callback func to do send & recv. don't get fd from sock as it is deprecated in java

* fix java build

* fix min/max macro define in windows

* keep the old rpc setup for py

* add doc for CallbackChannel
  • Loading branch information
yzhliu authored and tqchen committed Oct 23, 2019
1 parent a740423 commit 5408d3a
Show file tree
Hide file tree
Showing 16 changed files with 165 additions and 146 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
private final String host;
private final int port;
private final String key;
private final SocketFileDescriptorGetter socketFileDescriptorGetter;

private volatile Socket currSocket = new Socket();
private Runnable callback;
Expand All @@ -40,14 +39,11 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
* @param host Proxy server host.
* @param port Proxy server port.
* @param key Proxy server key.
* @param sockFdGetter Method to get file descriptor from Java socket.
*/
public ConnectProxyServerProcessor(String host, int port, String key,
SocketFileDescriptorGetter sockFdGetter) {
public ConnectProxyServerProcessor(String host, int port, String key) {
this.host = host;
this.port = port;
this.key = "server:" + key;
socketFileDescriptorGetter = sockFdGetter;
}

/**
Expand All @@ -70,8 +66,8 @@ public void setStartTimeCallback(Runnable callback) {
try {
SocketAddress address = new InetSocketAddress(host, port);
currSocket.connect(address, 6000);
InputStream in = currSocket.getInputStream();
OutputStream out = currSocket.getOutputStream();
final InputStream in = currSocket.getInputStream();
final OutputStream out = currSocket.getOutputStream();
out.write(Utils.toBytes(RPC.RPC_MAGIC));
out.write(Utils.toBytes(key.length()));
out.write(Utils.toBytes(key));
Expand All @@ -91,11 +87,10 @@ public void setStartTimeCallback(Runnable callback) {
if (callback != null) {
callback.run();
}
final int sockFd = socketFileDescriptorGetter.get(currSocket);
if (sockFd != -1) {
new NativeServerLoop(sockFd).run();
System.err.println("Finish serving " + address);
}

SocketChannel sockChannel = new SocketChannel(currSocket);
new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run();
System.err.println("Finish serving " + address);
} catch (Throwable e) {
e.printStackTrace();
throw new RuntimeException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
*/
public class ConnectTrackerServerProcessor implements ServerProcessor {
private ServerSocket server;
private final SocketFileDescriptorGetter socketFileDescriptorGetter;
private final String trackerHost;
private final int trackerPort;
// device key
Expand All @@ -62,10 +61,11 @@ public class ConnectTrackerServerProcessor implements ServerProcessor {
* @param trackerHost Tracker host.
* @param trackerPort Tracker port.
* @param key Device key.
* @param sockFdGetter Method to get file descriptor from Java socket.
* @param watchdog watch for timeout, etc.
* @throws java.io.IOException when socket fails to open.
*/
public ConnectTrackerServerProcessor(String trackerHost, int trackerPort, String key,
SocketFileDescriptorGetter sockFdGetter, RPCWatchdog watchdog) throws IOException {
RPCWatchdog watchdog) throws IOException {
while (true) {
try {
this.server = new ServerSocket(serverPort);
Expand All @@ -81,7 +81,6 @@ public ConnectTrackerServerProcessor(String trackerHost, int trackerPort, String
}
}
System.err.println("using port: " + serverPort);
this.socketFileDescriptorGetter = sockFdGetter;
this.trackerHost = trackerHost;
this.trackerPort = trackerPort;
this.key = key;
Expand Down Expand Up @@ -163,11 +162,9 @@ public String getMatchKey() {
System.err.println("Connection from " + socket.getRemoteSocketAddress().toString());
// received timeout in seconds
watchdog.startTimeout(timeout * 1000);
final int sockFd = socketFileDescriptorGetter.get(socket);
if (sockFd != -1) {
new NativeServerLoop(sockFd).run();
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
}
SocketChannel sockChannel = new SocketChannel(socket);
new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run();
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
Utils.closeQuietly(socket);
} catch (ConnectException e) {
// if the tracker connection failed, wait a bit before retrying
Expand Down
13 changes: 8 additions & 5 deletions jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,25 @@
* Call native ServerLoop on socket file descriptor.
*/
public class NativeServerLoop implements Runnable {
private final int sockFd;
private final Function fsend;
private final Function frecv;

/**
* Constructor for NativeServerLoop.
* @param nativeSockFd native socket file descriptor.
* @param fsend socket.send function.
* @param frecv socket.recv function.
*/
public NativeServerLoop(final int nativeSockFd) {
sockFd = nativeSockFd;
public NativeServerLoop(final Function fsend, final Function frecv) {
this.fsend = fsend;
this.frecv = frecv;
}

@Override public void run() {
File tempDir = null;
try {
tempDir = serverEnv();
System.err.println("starting server loop...");
RPC.getApi("_ServerLoop").pushArg(sockFd).invoke();
RPC.getApi("_ServerLoop").pushArg(fsend).pushArg(frecv).invoke();
System.err.println("done server loop...");
} catch (IOException e) {
e.printStackTrace();
Expand Down
2 changes: 2 additions & 0 deletions jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ public void upload(byte[] data, String target) {
* Upload file to remote runtime temp folder.
* @param data The file in local to upload.
* @param target The path in remote.
* @throws java.io.IOException for network failure.
*/
public void upload(File data, String target) throws IOException {
byte[] blob = getBytesFromFile(data);
Expand All @@ -209,6 +210,7 @@ public void upload(File data, String target) throws IOException {
/**
* Upload file to remote runtime temp folder.
* @param data The file in local to upload.
* @throws java.io.IOException for network failure.
*/
public void upload(File data) throws IOException {
upload(data, data.getName());
Expand Down
49 changes: 3 additions & 46 deletions jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,12 @@

package ml.dmlc.tvm.rpc;

import sun.misc.SharedSecrets;

import java.io.FileDescriptor;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.Socket;

/**
* RPC Server.
*/
public class Server {
private static SocketFileDescriptorGetter defaultSocketFdGetter
= new SocketFileDescriptorGetter() {
@Override public int get(Socket socket) {
try {
InputStream is = socket.getInputStream();
FileDescriptor fd = ((FileInputStream) is).getFD();
return SharedSecrets.getJavaIOFileDescriptorAccess().get(fd);
} catch (IOException e) {
e.printStackTrace();
return -1;
}
}
};
private final WorkerThread worker;

private static class WorkerThread extends Thread {
Expand Down Expand Up @@ -72,35 +53,10 @@ public void terminate() {
/**
* Start a standalone server.
* @param serverPort Port.
* @param socketFdGetter Method to get system file descriptor of the server socket.
* @throws IOException if failed to bind localhost:port.
*/
public Server(int serverPort, SocketFileDescriptorGetter socketFdGetter) throws IOException {
worker = new WorkerThread(new StandaloneServerProcessor(serverPort, socketFdGetter));
}

/**
* Start a standalone server.
* Use sun.misc.SharedSecrets.getJavaIOFileDescriptorAccess
* to get file descriptor for the socket.
* @param serverPort Port.
* @throws IOException if failed to bind localhost:port.
*/
public Server(int serverPort) throws IOException {
this(serverPort, defaultSocketFdGetter);
}

/**
* Start a server connected to proxy.
* @param proxyHost The proxy server host.
* @param proxyPort The proxy server port.
* @param key The key to identify the server.
* @param socketFdGetter Method to get system file descriptor of the server socket.
*/
public Server(String proxyHost, int proxyPort, String key,
SocketFileDescriptorGetter socketFdGetter) {
worker = new WorkerThread(
new ConnectProxyServerProcessor(proxyHost, proxyPort, key, socketFdGetter));
worker = new WorkerThread(new StandaloneServerProcessor(serverPort));
}

/**
Expand All @@ -112,7 +68,8 @@ public Server(String proxyHost, int proxyPort, String key,
* @param key The key to identify the server.
*/
public Server(String proxyHost, int proxyPort, String key) {
this(proxyHost, proxyPort, key, defaultSocketFdGetter);
worker = new WorkerThread(
new ConnectProxyServerProcessor(proxyHost, proxyPort, key));
}

/**
Expand Down
49 changes: 49 additions & 0 deletions jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketChannel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package ml.dmlc.tvm.rpc;

import ml.dmlc.tvm.Function;
import ml.dmlc.tvm.TVMValue;
import ml.dmlc.tvm.TVMValueBytes;

import java.io.IOException;
import java.net.Socket;

public class SocketChannel {
private final Socket socket;

SocketChannel(Socket sock) {
socket = sock;
}

private Function fsend = Function.convertFunc(new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
byte[] data = args[0].asBytes();
try {
socket.getOutputStream().write(data);
} catch (IOException e) {
e.printStackTrace();
return -1;
}
return data.length;
}
});

private Function frecv = Function.convertFunc(new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
long size = args[0].asLong();
try {
return new TVMValueBytes(Utils.recvAll(socket.getInputStream(), (int) size));
} catch (IOException e) {
e.printStackTrace();
return -1;
}
}
});

public Function getFsend() {
return fsend;
}

public Function getFrecv() {
return frecv;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,9 @@
*/
public class StandaloneServerProcessor implements ServerProcessor {
private final ServerSocket server;
private final SocketFileDescriptorGetter socketFileDescriptorGetter;

public StandaloneServerProcessor(int serverPort,
SocketFileDescriptorGetter sockFdGetter) throws IOException {
public StandaloneServerProcessor(int serverPort) throws IOException {
this.server = new ServerSocket(serverPort);
this.socketFileDescriptorGetter = sockFdGetter;
}

@Override public void terminate() {
Expand All @@ -46,9 +43,9 @@ public StandaloneServerProcessor(int serverPort,

@Override public void run() {
try {
Socket socket = server.accept();
InputStream in = socket.getInputStream();
OutputStream out = socket.getOutputStream();
final Socket socket = server.accept();
final InputStream in = socket.getInputStream();
final OutputStream out = socket.getOutputStream();
int magic = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt();
if (magic != RPC.RPC_MAGIC) {
Utils.closeQuietly(socket);
Expand All @@ -66,12 +63,10 @@ public StandaloneServerProcessor(int serverPort,
out.write(Utils.toBytes(serverKey));
}

SocketChannel sockChannel = new SocketChannel(socket);
System.err.println("Connection from " + socket.getRemoteSocketAddress().toString());
final int sockFd = socketFileDescriptorGetter.get(socket);
if (sockFd != -1) {
new NativeServerLoop(sockFd).run();
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
}
new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run();
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
Utils.closeQuietly(socket);
} catch (Throwable e) {
e.printStackTrace();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

package ml.dmlc.tvm.contrib;

import ml.dmlc.tvm.*;
import ml.dmlc.tvm.Module;
import ml.dmlc.tvm.NDArray;
import ml.dmlc.tvm.TVMContext;
import ml.dmlc.tvm.TestUtils;
import ml.dmlc.tvm.rpc.Client;
import ml.dmlc.tvm.rpc.RPCSession;
import ml.dmlc.tvm.rpc.Server;
Expand Down
4 changes: 2 additions & 2 deletions jvm/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@
<artifactId>maven-compiler-plugin</artifactId>
<version>3.3</version>
<configuration>
<source>1.6</source>
<target>1.6</target>
<source>1.7</source>
<target>1.7</target>
<encoding>UTF-8</encoding>
</configuration>
</plugin>
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/rpc/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def call_handler(self, args):
port, matchkey = args[2]
self.pending_matchkeys.add(matchkey)
# got custom address (from rpc server)
if args[3] is not None:
if len(args) >= 4 and args[3] is not None:
value = (self, args[3], port, matchkey)
else:
value = (self, self._addr[0], port, matchkey)
Expand Down
2 changes: 2 additions & 0 deletions src/common/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
#define TVM_COMMON_SOCKET_H_

#if defined(_WIN32)
#define NOMINMAX
#include <winsock2.h>
#include <ws2tcpip.h>
#undef NOMINMAX
using ssize_t = int;
#ifdef _MSC_VER
#pragma comment(lib, "Ws2_32.lib")
Expand Down
Loading

0 comments on commit 5408d3a

Please sign in to comment.