blob: e217d4f8978a5eeeb145e7e53ad1731c46591fcb [file] [log] [blame]
//
// ========================================================================
// Copyright (c) 1995-2016 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.client.ssl;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.io.OutputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.toolchain.test.TestTracker;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.junit.Assert;
import org.junit.Rule;
/**
 * Base class for byte-level SSL tests.
 * <p>
 * Provides {@link TLSRecord}, a holder for a single record read off the wire, and
 * {@link SimpleProxy}, a man-in-the-middle proxy that lets tests read, inspect and
 * forward individual records between a client and a server.
 */
public abstract class SslBytesTest
{
    @Rule
    public TestTracker tracker = new TestTracker();
    protected final Logger logger = Log.getLogger(getClass());

    /**
     * A single record captured from the wire: its content type and the raw bytes,
     * header included.
     */
    public static class TLSRecord
    {
        private final TLSRecord.Type type;
        private final byte[] bytes;

        /**
         * @param type the record content type
         * @param bytes the raw record bytes, including the record header
         */
        public TLSRecord(TLSRecord.Type type, byte[] bytes)
        {
            this.type = type;
            this.bytes = bytes;
        }

        /**
         * @return the record content type
         */
        public TLSRecord.Type getType()
        {
            return type;
        }

        /**
         * @return the raw record bytes, including the record header (not a copy)
         */
        public byte[] getBytes()
        {
            return bytes;
        }

        @Override
        public String toString()
        {
            return "TLSRecord [" + type + "] " + bytes.length + " bytes";
        }

        /**
         * TLS record content types, keyed by the numeric code found in the first
         * byte of a TLS record header.
         */
        public enum Type
        {
            CHANGE_CIPHER_SPEC(20), ALERT(21), HANDSHAKE(22), APPLICATION(23);

            private final int code;

            private Type(int code)
            {
                this.code = code;
                // Enum constructors cannot touch the enum's own static fields, so the
                // lookup table lives in a separate holder class.
                TLSRecord.Type.Mapper.codes.put(this.code, this);
            }

            /**
             * Maps a numeric record content type to its enum constant.
             *
             * @param code the content type byte from a TLS record header
             * @return the matching constant
             * @throws IllegalArgumentException if the code is not a known content type
             */
            public static TLSRecord.Type from(int code)
            {
                TLSRecord.Type result = TLSRecord.Type.Mapper.codes.get(code);
                if (result == null)
                    throw new IllegalArgumentException("Invalid TLSRecord.Type " + code);
                return result;
            }

            private static class Mapper
            {
                private static final Map<Integer, TLSRecord.Type> codes = new HashMap<>();
            }
        }
    }

    /**
     * A proxy sitting between a client and a server.
     * <p>
     * Tests pull individual records from either side and decide whether to forward
     * them. Alternatively, they call {@link #startAutomaticFlow()} to pump records
     * in both directions on the supplied thread pool.
     */
    public class SimpleProxy implements Runnable
    {
        private final CountDownLatch latch = new CountDownLatch(1);
        private final ExecutorService threadPool;
        private final int proxyPort;
        private final String serverHost;
        private final int serverPort;
        private volatile ServerSocket serverSocket;
        private volatile Socket server;
        private volatile Socket client;

        /**
         * Creates a proxy listening on an ephemeral port.
         */
        public SimpleProxy(ExecutorService threadPool, String serverHost, int serverPort)
        {
            this(threadPool, 0, serverHost, serverPort);
        }

        /**
         * @param threadPool executor used by {@link #startAutomaticFlow()}
         * @param proxyPort local port to listen on (0 for an ephemeral port)
         * @param serverHost host of the real server
         * @param serverPort port of the real server
         */
        public SimpleProxy(ExecutorService threadPool, int proxyPort, String serverHost, int serverPort)
        {
            this.threadPool = threadPool;
            this.proxyPort = proxyPort;
            this.serverHost = serverHost;
            this.serverPort = serverPort;
        }

        /**
         * Opens the listening socket, starts a thread that accepts a single client,
         * and connects to the real server.
         */
        public void start() throws Exception
        {
            serverSocket = new ServerSocket(proxyPort);
            Thread acceptor = new Thread(this);
            acceptor.start();
            server = new Socket(serverHost, serverPort);
        }

        /**
         * Closes the server connection, the client connection (if any) and the
         * listening socket. Every socket is closed even if closing an earlier one
         * fails, and sockets that were never opened are skipped.
         */
        public void stop() throws Exception
        {
            try
            {
                if (server != null)
                    server.close();
            }
            finally
            {
                try
                {
                    if (client != null) // some tests only run on linux, those won't create a client on other OS
                        client.close();
                }
                finally
                {
                    if (serverSocket != null)
                        serverSocket.close();
                }
            }
        }

        /**
         * Acceptor: waits for one client connection and then releases
         * {@link #awaitClient(int, TimeUnit)}.
         */
        @Override
        public void run()
        {
            try
            {
                client = serverSocket.accept();
                latch.countDown();
            }
            catch (IOException x)
            {
                logger.warn("Proxy failed to accept client connection", x);
            }
        }

        /**
         * @return the local port the proxy is listening on
         */
        public int getPort()
        {
            return serverSocket.getLocalPort();
        }

        /**
         * Reads the next record sent by the client.
         *
         * @return the record, or {@code null} if the client closed its output
         * @throws InterruptedIOException if the current thread was interrupted while waiting
         * @throws EOFException if the stream ended in the middle of a record
         */
        public TLSRecord readFromClient() throws IOException
        {
            TLSRecord record = read(client);
            logger.debug("C --> P {}", record);
            return record;
        }

        private TLSRecord read(Socket socket) throws IOException
        {
            InputStream input = socket.getInputStream();
            // Poll with a short timeout so that an interrupt (e.g. Future.cancel(true)
            // from AutomaticFlow.stop) is noticed even while blocked on the socket.
            socket.setSoTimeout(500);
            int first = -2;
            while (true)
            {
                try
                {
                    first = input.read();
                    break;
                }
                catch (SocketTimeoutException x)
                {
                    if (Thread.currentThread().isInterrupted())
                        break;
                }
            }
            if (first == -2)
                throw new InterruptedIOException();
            else if (first == -1)
                return null;

            if (first >= 0x80)
            {
                // SSLv2 Record with a 2-byte header: the MSB flags the short header,
                // the remaining 15 bits are the record length.
                int hiLength = first & 0x7F;
                int loLength = readByte(input);
                int length = (hiLength << 8) + loLength;
                byte[] bytes = new byte[2 + length];
                bytes[0] = (byte)first;
                bytes[1] = (byte)loLength;
                return read(TLSRecord.Type.HANDSHAKE, input, bytes, 2, length);
            }
            else
            {
                // TLS Record: type(1) + version(2) + length(2) + payload
                int major = readByte(input);
                int minor = readByte(input);
                int hiLength = readByte(input);
                int loLength = readByte(input);
                int length = (hiLength << 8) + loLength;
                byte[] bytes = new byte[1 + 2 + 2 + length];
                bytes[0] = (byte)first;
                bytes[1] = (byte)major;
                bytes[2] = (byte)minor;
                bytes[3] = (byte)hiLength;
                bytes[4] = (byte)loLength;
                return read(TLSRecord.Type.from(first), input, bytes, 5, length);
            }
        }

        /**
         * Reads one header byte, failing fast on a truncated stream. Without this
         * check a -1 would corrupt the computed record length.
         */
        private int readByte(InputStream input) throws IOException
        {
            int value = input.read();
            if (value < 0)
                throw new EOFException();
            return value;
        }

        /**
         * Fills {@code bytes} from {@code offset} with exactly {@code length} payload bytes.
         */
        private TLSRecord read(TLSRecord.Type type, InputStream input, byte[] bytes, int offset, int length) throws IOException
        {
            while (length > 0)
            {
                int read = input.read(bytes, offset, length);
                if (read < 0)
                    throw new EOFException();
                offset += read;
                length -= read;
            }
            return new TLSRecord(type, bytes);
        }

        /**
         * Forwards a record to the server, then sleeps 100 ms.
         */
        public void flushToServer(TLSRecord record) throws Exception
        {
            flushToServer(record, 100);
        }

        /**
         * Forwards a record to the server, then sleeps {@code sleep} ms.
         * <p>
         * A {@code null} record propagates a half-close: the proxy shuts down its
         * output to the server. If the client side has already been half-closed,
         * both sockets are closed.
         */
        public void flushToServer(TLSRecord record, long sleep) throws Exception
        {
            if (record == null)
            {
                server.shutdownOutput();
                if (client.isOutputShutdown())
                {
                    client.close();
                    server.close();
                }
            }
            else
            {
                flush(sleep, server, record.getBytes());
            }
        }

        /**
         * Writes raw bytes to the server, then sleeps {@code sleep} ms.
         */
        public void flushToServer(long sleep, byte... bytes) throws Exception
        {
            flush(sleep, server, bytes);
        }

        private void flush(long sleep, Socket socket, byte... bytes) throws Exception
        {
            OutputStream output = socket.getOutputStream();
            output.write(bytes);
            output.flush();
            if (sleep > 0)
                TimeUnit.MILLISECONDS.sleep(sleep);
        }

        /**
         * Reads the next record sent by the server.
         *
         * @return the record, or {@code null} if the server closed its output
         * @throws InterruptedIOException if the current thread was interrupted while waiting
         * @throws EOFException if the stream ended in the middle of a record
         */
        public TLSRecord readFromServer() throws IOException
        {
            TLSRecord record = read(server);
            logger.debug("P <-- S {}", record);
            return record;
        }

        /**
         * Forwards a record to the client.
         * <p>
         * A {@code null} record propagates a half-close: the proxy shuts down its
         * output to the client. If the server side has already been half-closed,
         * both sockets are closed.
         */
        public void flushToClient(TLSRecord record) throws Exception
        {
            if (record == null)
            {
                client.shutdownOutput();
                if (server.isOutputShutdown())
                {
                    server.close();
                    client.close();
                }
            }
            else
            {
                flush(0, client, record.getBytes());
            }
        }

        /**
         * Starts two tasks on the thread pool that forward records in both directions
         * until they are interrupted. Blocks until both tasks have started.
         *
         * @return a handle used to stop the automatic flow
         */
        public AutomaticFlow startAutomaticFlow() throws InterruptedException
        {
            final CountDownLatch startLatch = new CountDownLatch(2);
            final CountDownLatch stopLatch = new CountDownLatch(2);
            Future<Object> clientToServer = threadPool.submit(new Callable<Object>()
            {
                @Override
                public Object call() throws Exception
                {
                    startLatch.countDown();
                    logger.debug("Automatic flow C --> S started");
                    try
                    {
                        while (true)
                        {
                            flushToServer(readFromClient(), 0);
                        }
                    }
                    catch (InterruptedIOException x)
                    {
                        return null;
                    }
                    finally
                    {
                        stopLatch.countDown();
                        logger.debug("Automatic flow C --> S finished");
                    }
                }
            });
            Future<Object> serverToClient = threadPool.submit(new Callable<Object>()
            {
                @Override
                public Object call() throws Exception
                {
                    startLatch.countDown();
                    logger.debug("Automatic flow C <-- S started");
                    try
                    {
                        while (true)
                        {
                            flushToClient(readFromServer());
                        }
                    }
                    catch (InterruptedIOException x)
                    {
                        return null;
                    }
                    finally
                    {
                        stopLatch.countDown();
                        logger.debug("Automatic flow C <-- S finished");
                    }
                }
            });
            Assert.assertTrue(startLatch.await(5, TimeUnit.SECONDS));
            return new AutomaticFlow(stopLatch, clientToServer, serverToClient);
        }

        /**
         * Waits for the acceptor thread to accept a client connection.
         *
         * @return {@code true} if a client connected within the given time
         */
        public boolean awaitClient(int time, TimeUnit unit) throws InterruptedException
        {
            return latch.await(time, unit);
        }

        /**
         * Abortively closes the connection to the server.
         */
        public void sendRSTToServer() throws IOException
        {
            // Calling setSoLinger(true, 0) causes close()
            // to send a RST instead of a FIN, causing an
            // exception to be thrown on the other end
            server.setSoLinger(true, 0);
            server.close();
        }

        /**
         * Handle to the two forwarding tasks started by {@link #startAutomaticFlow()}.
         */
        public class AutomaticFlow
        {
            private final CountDownLatch stopLatch;
            private final Future<Object> clientToServer;
            private final Future<Object> serverToClient;

            public AutomaticFlow(CountDownLatch stopLatch, Future<Object> clientToServer, Future<Object> serverToClient)
            {
                this.stopLatch = stopLatch;
                this.clientToServer = clientToServer;
                this.serverToClient = serverToClient;
            }

            /**
             * Interrupts both forwarding tasks and waits for them to finish.
             *
             * @return {@code true} if both tasks finished within the given time
             */
            public boolean stop(long time, TimeUnit unit) throws InterruptedException
            {
                clientToServer.cancel(true);
                serverToClient.cancel(true);
                return stopLatch.await(time, unit);
            }
        }
    }
}