View Javadoc
1   /*
2    *  Licensed to the Apache Software Foundation (ASF) under one
3    *  or more contributor license agreements.  See the NOTICE file
4    *  distributed with this work for additional information
5    *  regarding copyright ownership.  The ASF licenses this file
6    *  to you under the Apache License, Version 2.0 (the
7    *  "License"); you may not use this file except in compliance
8    *  with the License.  You may obtain a copy of the License at
9    *
10   *    http://www.apache.org/licenses/LICENSE-2.0
11   *
12   *  Unless required by applicable law or agreed to in writing,
13   *  software distributed under the License is distributed on an
14   *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   *  KIND, either express or implied.  See the License for the
16   *  specific language governing permissions and limitations
17   *  under the License.
18   *
19   */package org.apache.mina.filter.ssl;
20  
21  import java.io.BufferedReader;
22  import java.io.IOException;
23  import java.io.InputStreamReader;
24  import java.net.InetAddress;
25  import java.net.InetSocketAddress;
26  import java.net.Socket;
27  import java.net.SocketTimeoutException;
28  import java.security.GeneralSecurityException;
29  import java.security.KeyStore;
30  import java.security.Security;
31  
32  import javax.net.ssl.KeyManagerFactory;
33  import javax.net.ssl.SSLContext;
34  import javax.net.ssl.SSLSocketFactory;
35  import javax.net.ssl.TrustManagerFactory;
36  
37  import org.apache.mina.core.filterchain.DefaultIoFilterChainBuilder;
38  import org.apache.mina.core.service.IoHandlerAdapter;
39  import org.apache.mina.core.session.IoSession;
40  import org.apache.mina.filter.codec.ProtocolCodecFilter;
41  import org.apache.mina.filter.codec.textline.TextLineCodecFactory;
42  import org.apache.mina.transport.socket.nio.NioSocketAcceptor;
43  import org.apache.mina.util.AvailablePortFinder;
44  import org.junit.Test;
45  
46  /**
47   * Test a SSL session where the connection is established and closed twice. It should be
48   * processed correctly (Test for DIRMINA-650)
49   *
50   * @author <a href="http://mina.apache.org">Apache MINA Project</a>
51   */
52  public class SslTest {
53      /** A static port used for his test, chosen to avoid collisions */
54      private static final int port = AvailablePortFinder.getNextAvailable(5555);
55  
56      private static Exception clientError = null;
57  
58      private static InetAddress address;
59  
60      private static SSLSocketFactory factory;
61      
62      private static NioSocketAcceptor acceptor;
63  
64      /** A JVM independant KEY_MANAGER_FACTORY algorithm */
65      private static final String KEY_MANAGER_FACTORY_ALGORITHM;
66  
67      static {
68          String algorithm = Security.getProperty("ssl.KeyManagerFactory.algorithm");
69          if (algorithm == null) {
70              algorithm = KeyManagerFactory.getDefaultAlgorithm();
71          }
72  
73          KEY_MANAGER_FACTORY_ALGORITHM = algorithm;
74      }
75  
76      private static class TestHandler extends IoHandlerAdapter {
77          public void messageReceived(IoSession session, Object message) throws Exception {
78              String line = (String) message;
79  
80              if (line.startsWith("hello")) {
81                  //System.out.println("Server got: 'hello', waiting for 'send'");
82                  Thread.sleep(1500);
83              } else if (line.startsWith("send")) {
84                  //System.out.println("Server got: 'send', sending 'data'");
85                  StringBuilder sb = new StringBuilder();
86                  
87                  for ( int i = 0; i < 10000; i++) {
88                      sb.append('A');
89                  }
90                      
91                  session.write(sb.toString());
92                  session.closeOnFlush();
93              }
94          }
95      }
96  
97      /**
98       * Starts a Server with the SSL Filter and a simple text line 
99       * protocol codec filter
100      */
101     private static void startServer() throws Exception {
102         acceptor = new NioSocketAcceptor();
103 
104         acceptor.setReuseAddress(true);
105         DefaultIoFilterChainBuilder filters = acceptor.getFilterChain();
106 
107         // Inject the SSL filter
108         SslFilter sslFilter = new SslFilter(createSSLContext());
109         filters.addLast("sslFilter", sslFilter);
110         sslFilter.setNeedClientAuth(true);
111 
112         // Inject the TestLine codec filter
113         filters.addLast("text", new ProtocolCodecFilter(new TextLineCodecFactory()));
114 
115         acceptor.setHandler(new TestHandler());
116         acceptor.bind(new InetSocketAddress(port));
117     }
118     
119     private static void stopServer() {
120         acceptor.dispose();
121     }
122 
123     /**
124      * Starts a client which will connect twice using SSL
125      */
126     private static void startClient() throws Exception {
127         address = InetAddress.getByName("localhost");
128 
129         SSLContext context = createSSLContext();
130         factory = context.getSocketFactory();
131 
132         connectAndSend();
133 
134         // This one will throw a SocketTimeoutException if DIRMINA-650 is not fixed
135         connectAndSend();
136     }
137 
138     private static void connectAndSend() throws Exception {
139         Socket parent = new Socket(address, port);
140         Socket socket = factory.createSocket(parent, address.getCanonicalHostName(), port, false);
141 
142         //System.out.println("Client sending: hello");
143         socket.getOutputStream().write("hello                      \n".getBytes());
144         socket.getOutputStream().flush();
145         socket.setSoTimeout(1000000);
146 
147         //System.out.println("Client sending: send");
148         socket.getOutputStream().write("send\n".getBytes());
149         socket.getOutputStream().flush();
150 
151         BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()));
152         String line = in.readLine();
153         //System.out.println("Client got: " + line);
154         socket.close();
155 
156     }
157 
158     private static SSLContext createSSLContext() throws IOException, GeneralSecurityException {
159         char[] passphrase = "password".toCharArray();
160 
161         SSLContext ctx = SSLContext.getInstance("TLS");
162         KeyManagerFactory kmf = KeyManagerFactory.getInstance(KEY_MANAGER_FACTORY_ALGORITHM);
163         TrustManagerFactory tmf = TrustManagerFactory.getInstance(KEY_MANAGER_FACTORY_ALGORITHM);
164 
165         KeyStore ks = KeyStore.getInstance("JKS");
166         KeyStore ts = KeyStore.getInstance("JKS");
167 
168         ks.load(SslTest.class.getResourceAsStream("keystore.sslTest"), passphrase);
169         ts.load(SslTest.class.getResourceAsStream("truststore.sslTest"), passphrase);
170 
171         kmf.init(ks, passphrase);
172         tmf.init(ts);
173         ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
174 
175         return ctx;
176     }
177 
178     @Test
179     public void testSSL() throws Exception {
180         try {
181             startServer();
182     
183             Thread t = new Thread() {
184                 public void run() {
185                     try {
186                         startClient();
187                     } catch (Exception e) {
188                         clientError = e;
189                     }
190                 }
191             };
192             t.start();
193             t.join();
194             
195             if (clientError != null) {
196                 throw clientError;
197             }
198         } finally {
199             stopServer();
200         }
201     }
202     
203     
204     @Test
205     public void unsecureClientTryToConnectoToSecureServer() throws Exception {
206         try {
207             startServer(); // Start Server with SSLFilter
208     
209             //Now start a client without any SSL
210             Thread t = new Thread() {
211                 @Override
212                 public void run() {
213                     try {
214                         address = InetAddress.getByName("localhost");
215     
216                         Socket socket = new Socket(address, port);
217                         socket.setSoTimeout(10000);
218     
219                         String response = null;
220     
221                         while (response == null) {
222                             try {
223                                 System.out.println(socket.isConnected());
224                                 // System.out.println("Client sending: hello");
225                                 socket.getOutputStream().write("hello                      \n".getBytes());
226                                 socket.getOutputStream().flush();
227                                 socket.setSoTimeout(1000);
228     
229                                 // System.out.println("Client sending: send");
230                                 socket.getOutputStream().write("send\n".getBytes());
231                                 socket.getOutputStream().flush();
232     
233                                 BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()));
234                                 String line = "";
235                                 
236                                 while ((line = in.readLine()) != null) {
237                                     response = response + line;
238                                 }
239                             } catch (SocketTimeoutException timeout) {
240                                 // donothing
241                                 timeout.printStackTrace();
242                             }
243                         }
244                         
245                         if (response.contains("AAAAAAA")){
246                             throw new IllegalStateException("getting response:" + response);
247                         }
248                         
249                         // System.out.println("Client got: " + line);
250                         socket.close();
251                     } catch (Exception e) {
252                         clientError = e;
253                     }
254                 }
255             };
256             
257             t.start();
258             t.join();
259             
260             if (clientError != null) {
261                 throw clientError;
262             }
263         } finally {
264             stopServer();
265         }
266     }
267 }