Skip to content

Commit

Permalink
Implement http2
Browse files Browse the repository at this point in the history
  • Loading branch information
chhsiao90 committed Mar 13, 2017
1 parent c594c22 commit 73f1d02
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 34 deletions.
6 changes: 6 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
<tcnative.classifier>${os.detected.classifier}</tcnative.classifier>
<guava.version>20.0</guava.version>
<logback.version>1.2.1</logback.version>
<bouncycastle.version>1.54</bouncycastle.version>
</properties>

<build>
Expand Down Expand Up @@ -48,6 +49,11 @@
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
</dependency>
<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcpkix-jdk15on</artifactId>
<version>${bouncycastle.version}</version>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
public class NitmProxy {
public static void main(String[] args) throws Exception {
NitmProxyConfig config = new NitmProxyConfig();
config.setMaxContentLength(4096);
config.setProxyMode(ProxyMode.HTTP);

NioEventLoopGroup bossGroup = new NioEventLoopGroup(1);
NioEventLoopGroup workerGroup = new NioEventLoopGroup();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,21 @@ public class NitmProxyConfig {

private List<Integer> httpsPorts;

private String certFile;

private String keyFile;

private int maxContentLength;

public NitmProxyConfig() {
// Defaults
proxyMode = ProxyMode.HTTP;

httpsPorts = Arrays.asList(443, 8443);
certFile = "server.pem";
keyFile = "key.pem";

maxContentLength = 4096;
maxContentLength = 1024 * 1024;
}

public ProxyMode getProxyMode() {
Expand All @@ -39,6 +45,22 @@ public boolean isTls(int port) {
return httpsPorts.contains(port);
}

public String getCertFile() {
return certFile;
}

public void setCertFile(String certFile) {
this.certFile = certFile;
}

public String getKeyFile() {
return keyFile;
}

public void setKeyFile(String keyFile) {
this.keyFile = keyFile;
}

public int getMaxContentLength() {
return maxContentLength;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,20 @@
import com.github.chhsiao.nitm.nitmproxy.NitmProxyConfig;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http2.Http2Codec;
import io.netty.util.ReferenceCountUtil;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http2.DefaultHttp2Connection;
import io.netty.handler.codec.http2.DelegatingDecompressorFrameListener;
import io.netty.handler.codec.http2.Http2Connection;
import io.netty.handler.codec.http2.Http2FrameLogger;
import io.netty.handler.codec.http2.HttpConversionUtil.ExtensionHeaderNames;
import io.netty.handler.codec.http2.HttpToHttp2ConnectionHandlerBuilder;
import io.netty.handler.codec.http2.InboundHttp2ToHttpAdapterBuilder;
import io.netty.handler.logging.LogLevel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -40,7 +50,20 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
LOGGER.info("{} : handlerAdded", connectionInfo);

ctx.pipeline().addBefore(ctx.name(), null, new Http2Codec(false, new Http2Handler()));
Http2Connection connection = new DefaultHttp2Connection(false);
ChannelHandler http2ConnHandler = new HttpToHttp2ConnectionHandlerBuilder()
.frameListener(new DelegatingDecompressorFrameListener(
connection,
new InboundHttp2ToHttpAdapterBuilder(connection)
.maxContentLength(config.getMaxContentLength())
.propagateSettings(true)
.build()))
.frameLogger(new Http2FrameLogger(LogLevel.DEBUG))
.connection(connection)
.build();
ctx.pipeline()
.addBefore(ctx.name(), null, http2ConnHandler)
.addBefore(ctx.name(), null, new Http2Handler());
}

private class Http2Handler extends ChannelDuplexHandler {
Expand All @@ -51,5 +74,15 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
msg);
outboundChannel.writeAndFlush(msg);
}

@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof HttpMessage) {
HttpMessage httpMessage = (HttpMessage) msg;
httpMessage.headers().add(ExtensionHeaderNames.SCHEME.text(), "https");
}

ctx.writeAndFlush(msg, promise);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
import com.github.chhsiao.nitm.nitmproxy.ConnectionInfo;
import com.github.chhsiao.nitm.nitmproxy.NitmProxyConfig;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http2.Http2Codec;
import io.netty.handler.codec.http2.DefaultHttp2Connection;
import io.netty.handler.codec.http2.DelegatingDecompressorFrameListener;
import io.netty.handler.codec.http2.Http2Connection;
import io.netty.handler.codec.http2.Http2FrameLogger;
import io.netty.handler.codec.http2.HttpToHttp2ConnectionHandlerBuilder;
import io.netty.handler.codec.http2.InboundHttp2ToHttpAdapterBuilder;
import io.netty.handler.logging.LogLevel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -20,7 +26,6 @@ public class Http2FrontendHandler extends ChannelInboundHandlerAdapter {

private NitmProxyConfig config;
private ConnectionInfo connectionInfo;

private Channel outboundChannel;

public Http2FrontendHandler(NitmProxyConfig config, ConnectionInfo connectionInfo, Channel outboundChannel) {
Expand All @@ -33,8 +38,20 @@ public Http2FrontendHandler(NitmProxyConfig config, ConnectionInfo connectionInf
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
LOGGER.info("{} : handlerAdded", connectionInfo);

Http2Connection connection = new DefaultHttp2Connection(true);
ChannelHandler http2ConnHandler = new HttpToHttp2ConnectionHandlerBuilder()
.frameListener(new DelegatingDecompressorFrameListener(
connection,
new InboundHttp2ToHttpAdapterBuilder(connection)
.maxContentLength(config.getMaxContentLength())
.propagateSettings(true)
.build()))
.connection(connection)
.frameLogger(new Http2FrameLogger(LogLevel.DEBUG))
.build();
ctx.pipeline()
.addBefore(ctx.name(), null, new Http2Codec(true, new Http2Handler()));
.addBefore(ctx.name(), null, http2ConnHandler)
.addBefore(ctx.name(), null, new Http2Handler());
}

@Override
Expand All @@ -43,7 +60,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
outboundChannel.close();
}

private class Http2Handler extends ChannelDuplexHandler {
private class Http2Handler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
LOGGER.info("[Client ({})] => [Server ({})] : {}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,10 @@
import com.github.chhsiao.nitm.nitmproxy.layer.protocol.http2.Http2BackendHandler;
import com.github.chhsiao.nitm.nitmproxy.layer.protocol.http2.Http2FrontendHandler;
import com.github.chhsiao.nitm.nitmproxy.tls.TlsUtil;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.ssl.ApplicationProtocolNames;
import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler;
import io.netty.handler.ssl.SslHandler;
Expand All @@ -27,7 +20,6 @@
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

public class TlsHandler extends ChannelOutboundHandlerAdapter {
private static final Logger LOGGER = LoggerFactory.getLogger(TlsHandler.class);
Expand Down Expand Up @@ -57,7 +49,7 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
SslHandler sslHandler = TlsUtil.ctx(config, client, connectionInfo.getServerAddr().getHost()).newHandler(ctx.alloc());
ctx.pipeline()
.addBefore(ctx.name(), null, sslHandler)
.addBefore(ctx.name(), null, new AlpnHandler());
.addBefore(ctx.name(), null, new AlpnHandler(ctx));
} else {
configHttp1(ctx);
}
Expand All @@ -82,6 +74,14 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
}
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
LOGGER.error("{} : TlsHandler(client={}) exceptionCaught, message is {}", connectionInfo, client, cause.getMessage());

outboundChannel.close();
ctx.close();
}

private void flushPendings(ChannelHandlerContext ctx) {
synchronized (pendings) {
Iterator<Object> iterator = pendings.iterator();
Expand All @@ -105,24 +105,29 @@ private void configHttp1(ChannelHandlerContext ctx) {
private void configHttp2(ChannelHandlerContext ctx) {
if (client) {
Http2BackendHandler backendHandler = new Http2BackendHandler(config, connectionInfo, outboundChannel);
ctx.pipeline().replace(this, null, backendHandler);
ctx.pipeline().addBefore(ctx.name(), null, backendHandler);
ctx.pipeline().remove(this);
} else {
Http2FrontendHandler frontendHandler = new Http2FrontendHandler(config, connectionInfo, outboundChannel);
ctx.pipeline().replace(this, null, frontendHandler);
ctx.pipeline().addBefore(ctx.name(), null, frontendHandler);
ctx.pipeline().remove(this);
}
}

private class AlpnHandler extends ApplicationProtocolNegotiationHandler {
private AlpnHandler() {
private ChannelHandlerContext tlsCtx;

private AlpnHandler(ChannelHandlerContext tlsCtx) {
super(ApplicationProtocolNames.HTTP_1_1);
this.tlsCtx = tlsCtx;
}

@Override
protected void configurePipeline(ChannelHandlerContext ctx, String protocol) throws Exception {
if (ApplicationProtocolNames.HTTP_2.equals(protocol)) {
configHttp2(ctx);
configHttp2(tlsCtx);
} else if (ApplicationProtocolNames.HTTP_1_1.equals(protocol)) {
configHttp1(ctx);
configHttp1(tlsCtx);
} else {
throw new IllegalStateException("unknown protocol: " + protocol);
}
Expand Down
109 changes: 109 additions & 0 deletions src/main/java/com/github/chhsiao/nitm/nitmproxy/tls/CertUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package com.github.chhsiao.nitm.nitmproxy.tls;

import org.bouncycastle.asn1.x500.X500Name;
import org.bouncycastle.cert.X509CertificateHolder;
import org.bouncycastle.cert.X509v3CertificateBuilder;
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.PEMWriter;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
import org.bouncycastle.operator.ContentSigner;
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;

import java.io.ByteArrayOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.math.BigInteger;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.Provider;
import java.security.SecureRandom;
import java.security.cert.X509Certificate;
import java.time.Instant;
import java.time.Year;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.util.Date;

public class CertUtil {
private static final Provider PROVIDER = new BouncyCastleProvider();

public static Certificate newCert(String parentCertFile, String keyFile, String host) {
try {
Date before = Date.from(Instant.now());
Date after = Date.from(Year.now().plus(3, ChronoUnit.YEARS).atDay(1).atStartOfDay(ZoneId.systemDefault()).toInstant());

KeyPair keyPair = createKeyPair();

X509CertificateHolder parent = readPemFromFile(parentCertFile);
PEMKeyPair parentPemKeyPair = readPemFromFile(keyFile);
KeyPair parentKeyPair = new JcaPEMKeyConverter()
.setProvider(PROVIDER)
.getKeyPair(parentPemKeyPair);

X509v3CertificateBuilder x509 = new JcaX509v3CertificateBuilder(
parent.getSubject(),
new BigInteger(64, new SecureRandom()),
before,
after,
new X500Name("CN=" + host),
keyPair.getPublic());

ContentSigner signer = new JcaContentSignerBuilder("SHA256WithRSAEncryption")
.build(parentKeyPair.getPrivate());

JcaX509CertificateConverter x509CertificateConverter = new JcaX509CertificateConverter()
.setProvider(PROVIDER);

return new Certificate(
keyPair,
x509CertificateConverter.getCertificate(x509.build(signer)),
x509CertificateConverter.getCertificate(parent));
} catch (Exception e) {
throw new IllegalStateException(e);
}
}

private static KeyPair createKeyPair() throws NoSuchAlgorithmException {
KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA");
keyGen.initialize(1024, new SecureRandom());
return keyGen.generateKeyPair();
}

public static <T> T readPemFromFile(String pemFile) throws IOException {
try (PEMParser pemParser = new PEMParser(new FileReader(pemFile))) {
Object o = pemParser.readObject();

@SuppressWarnings("unchecked")
T t = (T) o;
return t;
}
}

@SuppressWarnings("deprecation")
public static byte[] toPem(Object object) throws IOException {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
try (PEMWriter writer = new PEMWriter(new OutputStreamWriter(outputStream))) {
writer.writeObject(object);
writer.flush();
return outputStream.toByteArray();
}
}

@SuppressWarnings("deprecation")
public static byte[] toPem(Object... objects) throws IOException {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
try (PEMWriter writer = new PEMWriter(new OutputStreamWriter(outputStream))) {
for (Object object : objects) {
writer.writeObject(object);
}
writer.flush();
return outputStream.toByteArray();
}
}
}
Loading

0 comments on commit 73f1d02

Please sign in to comment.