001/*
002 *  Licensed to the Apache Software Foundation (ASF) under one
003 *  or more contributor license agreements.  See the NOTICE file
004 *  distributed with this work for additional information
005 *  regarding copyright ownership.  The ASF licenses this file
006 *  to you under the Apache License, Version 2.0 (the
007 *  "License"); you may not use this file except in compliance
008 *  with the License.  You may obtain a copy of the License at
009 *
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *
012 *  Unless required by applicable law or agreed to in writing,
013 *  software distributed under the License is distributed on an
014 *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 *  KIND, either express or implied.  See the License for the
016 *  specific language governing permissions and limitations
017 *  under the License.
018 *
019 */
020package org.apache.mina.example.echoserver.ssl;
021
022import static org.junit.Assert.assertEquals;
023import static org.junit.Assert.assertTrue;
024
025import java.net.InetSocketAddress;
026import java.net.Socket;
027import java.nio.charset.Charset;
028import java.security.cert.CertificateException;
029import java.util.ArrayList;
030import java.util.List;
031
032import javax.net.ssl.SSLContext;
033import javax.net.ssl.SSLSocket;
034import javax.net.ssl.TrustManager;
035import javax.net.ssl.X509TrustManager;
036
037import org.apache.mina.core.service.IoHandlerAdapter;
038import org.apache.mina.core.session.IoSession;
039import org.apache.mina.filter.codec.ProtocolCodecFilter;
040import org.apache.mina.filter.codec.textline.TextLineCodecFactory;
041import org.apache.mina.filter.ssl.SslFilter;
042import org.apache.mina.transport.socket.SocketAcceptor;
043import org.apache.mina.transport.socket.nio.NioSocketAcceptor;
044import org.junit.After;
045import org.junit.Before;
046import org.junit.Test;
047
048/**
049 * TODO Add documentation
050 * 
051 * @author <a href="http://mina.apache.org">Apache MINA Project</a>
052 */
053public class SslFilterTest {
054
055    private int port;
056    private SocketAcceptor acceptor;
057
058    @Before
059    public void setUp() throws Exception {
060        acceptor = new NioSocketAcceptor();
061    }
062
063    @After
064    public void tearDown() throws Exception {
065        acceptor.setCloseOnDeactivation(true);
066        acceptor.dispose();
067    }
068
069    @Test
070    public void testMessageSentIsCalled() throws Exception {
071        testMessageSentIsCalled(false);
072    }
073
074    @Test
075    public void testMessageSentIsCalled_With_SSL() throws Exception {
076        testMessageSentIsCalled(true);
077    }
078
079    private void testMessageSentIsCalled(boolean useSSL) throws Exception {
080        // Workaround to fix TLS issue : http://java.sun.com/javase/javaseforbusiness/docs/TLSReadme.html
081        java.lang.System.setProperty( "sun.security.ssl.allowUnsafeRenegotiation", "true" );
082
083        SslFilter sslFilter = null;
084        if (useSSL) {
085            sslFilter = new SslFilter(BogusSslContextFactory.getInstance(true));
086            acceptor.getFilterChain().addLast("sslFilter", sslFilter);
087        }
088        acceptor.getFilterChain().addLast(
089                "codec",
090                new ProtocolCodecFilter(new TextLineCodecFactory(Charset
091                        .forName("UTF-8"))));
092
093        EchoHandler handler = new EchoHandler();
094        acceptor.setHandler(handler);
095        acceptor.bind(new InetSocketAddress(0));
096        port = acceptor.getLocalAddress().getPort();
097        //System.out.println("MINA server started.");
098
099        Socket socket = getClientSocket(useSSL);
100        int bytesSent = 0;
101        bytesSent += writeMessage(socket, "test-1\n");
102
103        if (useSSL) {
104            // Test renegotiation
105            SSLSocket ss = (SSLSocket) socket;
106            //ss.getSession().invalidate();
107            ss.startHandshake();
108        }
109
110        bytesSent += writeMessage(socket, "test-2\n");
111
112        int[] response = new int[bytesSent];
113        for (int i = 0; i < response.length; i++) {
114            response[i] = socket.getInputStream().read();
115        }
116
117        if (useSSL) {
118            // Read SSL close notify.
119            while (socket.getInputStream().read() >= 0) {
120                continue;
121            }
122        }
123
124        socket.close();
125        while (acceptor.getManagedSessions().size() != 0) {
126            Thread.sleep(100);
127        }
128
129        //System.out.println("handler: " + handler.sentMessages);
130        assertEquals("handler should have sent 2 messages:", 2,
131                handler.sentMessages.size());
132        assertTrue(handler.sentMessages.contains("test-1"));
133        assertTrue(handler.sentMessages.contains("test-2"));
134    }
135
136    private int writeMessage(Socket socket, String message) throws Exception {
137        byte request[] = message.getBytes("UTF-8");
138        socket.getOutputStream().write(request);
139        return request.length;
140    }
141
142    private Socket getClientSocket(boolean ssl) throws Exception {
143        if (ssl) {
144            SSLContext ctx = SSLContext.getInstance("TLS");
145            ctx.init(null, trustManagers, null);
146            return ctx.getSocketFactory().createSocket("localhost", port);
147        }
148        return new Socket("localhost", port);
149    }
150
151    private static class EchoHandler extends IoHandlerAdapter {
152
153        List<String> sentMessages = new ArrayList<String>();
154
155        @Override
156        public void exceptionCaught(IoSession session, Throwable cause)
157                throws Exception {
158            //cause.printStackTrace();
159        }
160
161        @Override
162        public void messageReceived(IoSession session, Object message)
163                throws Exception {
164            session.write(message);
165        }
166
167        @Override
168        public void messageSent(IoSession session, Object message)
169                throws Exception {
170            sentMessages.add(message.toString());
171
172            if (sentMessages.size() >= 2) {
173                session.close(true);
174            }
175        }
176    }
177
178    TrustManager[] trustManagers = new TrustManager[] { new TrustAnyone() };
179
180    private static class TrustAnyone implements X509TrustManager {
181        public void checkClientTrusted(
182                java.security.cert.X509Certificate[] x509Certificates, String s)
183                throws CertificateException {
184        }
185
186        public void checkServerTrusted(
187                java.security.cert.X509Certificate[] x509Certificates, String s)
188                throws CertificateException {
189        }
190
191        public java.security.cert.X509Certificate[] getAcceptedIssuers() {
192            return new java.security.cert.X509Certificate[0];
193        }
194    }
195
196}