1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.mina.filter.ssl;
20
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.Security;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManagerFactory;

import org.apache.mina.core.filterchain.DefaultIoFilterChainBuilder;
import org.apache.mina.core.service.IoHandlerAdapter;
import org.apache.mina.core.session.IoSession;
import org.apache.mina.filter.codec.ProtocolCodecFilter;
import org.apache.mina.filter.codec.textline.TextLineCodecFactory;
import org.apache.mina.transport.socket.nio.NioSocketAcceptor;
import org.apache.mina.util.AvailablePortFinder;
import org.junit.Test;
45
46
47
48
49
50
51
52 public class SslTest {
53
54 private static final int port = AvailablePortFinder.getNextAvailable(5555);
55
56 private static Exception clientError = null;
57
58 private static InetAddress address;
59
60 private static SSLSocketFactory factory;
61
62 private static NioSocketAcceptor acceptor;
63
64
65 private static final String KEY_MANAGER_FACTORY_ALGORITHM;
66
67 static {
68 String algorithm = Security.getProperty("ssl.KeyManagerFactory.algorithm");
69 if (algorithm == null) {
70 algorithm = KeyManagerFactory.getDefaultAlgorithm();
71 }
72
73 KEY_MANAGER_FACTORY_ALGORITHM = algorithm;
74 }
75
76 private static class TestHandler extends IoHandlerAdapter {
77 public void messageReceived(IoSession session, Object message) throws Exception {
78 String line = (String) message;
79
80 if (line.startsWith("hello")) {
81
82 Thread.sleep(1500);
83 } else if (line.startsWith("send")) {
84
85 StringBuilder sb = new StringBuilder();
86
87 for ( int i = 0; i < 10000; i++) {
88 sb.append('A');
89 }
90
91 session.write(sb.toString());
92 session.closeOnFlush();
93 }
94 }
95 }
96
97
98
99
100
101 private static void startServer() throws Exception {
102 acceptor = new NioSocketAcceptor();
103
104 acceptor.setReuseAddress(true);
105 DefaultIoFilterChainBuilder filters = acceptor.getFilterChain();
106
107
108 SslFilter sslFilter = new SslFilter(createSSLContext());
109 filters.addLast("sslFilter", sslFilter);
110 sslFilter.setNeedClientAuth(true);
111
112
113 filters.addLast("text", new ProtocolCodecFilter(new TextLineCodecFactory()));
114
115 acceptor.setHandler(new TestHandler());
116 acceptor.bind(new InetSocketAddress(port));
117 }
118
119 private static void stopServer() {
120 acceptor.dispose();
121 }
122
123
124
125
126 private static void startClient() throws Exception {
127 address = InetAddress.getByName("localhost");
128
129 SSLContext context = createSSLContext();
130 factory = context.getSocketFactory();
131
132 connectAndSend();
133
134
135 connectAndSend();
136 }
137
138 private static void connectAndSend() throws Exception {
139 Socket parent = new Socket(address, port);
140 Socket socket = factory.createSocket(parent, address.getCanonicalHostName(), port, false);
141
142
143 socket.getOutputStream().write("hello \n".getBytes());
144 socket.getOutputStream().flush();
145 socket.setSoTimeout(1000000);
146
147
148 socket.getOutputStream().write("send\n".getBytes());
149 socket.getOutputStream().flush();
150
151 BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()));
152 String line = in.readLine();
153
154 socket.close();
155
156 }
157
158 private static SSLContext createSSLContext() throws IOException, GeneralSecurityException {
159 char[] passphrase = "password".toCharArray();
160
161 SSLContext ctx = SSLContext.getInstance("TLS");
162 KeyManagerFactory kmf = KeyManagerFactory.getInstance(KEY_MANAGER_FACTORY_ALGORITHM);
163 TrustManagerFactory tmf = TrustManagerFactory.getInstance(KEY_MANAGER_FACTORY_ALGORITHM);
164
165 KeyStore ks = KeyStore.getInstance("JKS");
166 KeyStore ts = KeyStore.getInstance("JKS");
167
168 ks.load(SslTest.class.getResourceAsStream("keystore.sslTest"), passphrase);
169 ts.load(SslTest.class.getResourceAsStream("truststore.sslTest"), passphrase);
170
171 kmf.init(ks, passphrase);
172 tmf.init(ts);
173 ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
174
175 return ctx;
176 }
177
178 @Test
179 public void testSSL() throws Exception {
180 try {
181 startServer();
182
183 Thread t = new Thread() {
184 public void run() {
185 try {
186 startClient();
187 } catch (Exception e) {
188 clientError = e;
189 }
190 }
191 };
192 t.start();
193 t.join();
194
195 if (clientError != null) {
196 throw clientError;
197 }
198 } finally {
199 stopServer();
200 }
201 }
202
203
204 @Test
205 public void unsecureClientTryToConnectoToSecureServer() throws Exception {
206 try {
207 startServer();
208
209
210 Thread t = new Thread() {
211 @Override
212 public void run() {
213 try {
214 address = InetAddress.getByName("localhost");
215
216 Socket socket = new Socket(address, port);
217 socket.setSoTimeout(10000);
218
219 String response = null;
220
221 while (response == null) {
222 try {
223 System.out.println(socket.isConnected());
224
225 socket.getOutputStream().write("hello \n".getBytes());
226 socket.getOutputStream().flush();
227 socket.setSoTimeout(1000);
228
229
230 socket.getOutputStream().write("send\n".getBytes());
231 socket.getOutputStream().flush();
232
233 BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()));
234 String line = "";
235
236 while ((line = in.readLine()) != null) {
237 response = response + line;
238 }
239 } catch (SocketTimeoutException timeout) {
240
241 timeout.printStackTrace();
242 }
243 }
244
245 if (response.contains("AAAAAAA")){
246 throw new IllegalStateException("getting response:" + response);
247 }
248
249
250 socket.close();
251 } catch (Exception e) {
252 clientError = e;
253 }
254 }
255 };
256
257 t.start();
258 t.join();
259
260 if (clientError != null) {
261 throw clientError;
262 }
263 } finally {
264 stopServer();
265 }
266 }
267 }