1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 package org.apache.mina.filter.ssl;
21
22 import java.net.InetSocketAddress;
23 import java.nio.ByteBuffer;
24 import java.util.Queue;
25 import java.util.concurrent.ConcurrentLinkedQueue;
26
27 import javax.net.ssl.SSLContext;
28 import javax.net.ssl.SSLEngine;
29 import javax.net.ssl.SSLEngineResult;
30 import javax.net.ssl.SSLException;
31 import javax.net.ssl.SSLHandshakeException;
32
33 import org.apache.mina.core.buffer.IoBuffer;
34 import org.apache.mina.core.filterchain.IoFilterEvent;
35 import org.apache.mina.core.filterchain.IoFilter.NextFilter;
36 import org.apache.mina.core.future.DefaultWriteFuture;
37 import org.apache.mina.core.future.WriteFuture;
38 import org.apache.mina.core.session.IoEventType;
39 import org.apache.mina.core.session.IoSession;
40 import org.apache.mina.core.write.DefaultWriteRequest;
41 import org.apache.mina.core.write.WriteRequest;
42 import org.apache.mina.util.CircularQueue;
43 import org.slf4j.Logger;
44 import org.slf4j.LoggerFactory;
45
46
47
48
49
50
51
52
53
54
55
56
57 class SslHandler {
58
59 private final Logger logger = LoggerFactory.getLogger(getClass());
60 private final SslFilter parent;
61 private final SSLContext ctx;
62 private final IoSession session;
63 private final Queue<IoFilterEvent> preHandshakeEventQueue = new CircularQueue<IoFilterEvent>();
64 private final Queue<IoFilterEvent> filterWriteEventQueue = new ConcurrentLinkedQueue<IoFilterEvent>();
65 private final Queue<IoFilterEvent> messageReceivedEventQueue = new ConcurrentLinkedQueue<IoFilterEvent>();
66 private SSLEngine sslEngine;
67
68
69
70
71 private IoBuffer inNetBuffer;
72
73
74
75
76 private IoBuffer outNetBuffer;
77
78
79
80
81 private IoBuffer appBuffer;
82
83
84
85
86 private final IoBuffer emptyBuffer = IoBuffer.allocate(0);
87
88 private SSLEngineResult.HandshakeStatus handshakeStatus;
89 private boolean initialHandshakeComplete;
90 private boolean handshakeComplete;
91 private boolean writingEncryptedData;
92
93
94
95
96
97
98
99 public SslHandler(SslFilter parent, SSLContext sslc, IoSession session)
100 throws SSLException {
101 this.parent = parent;
102 this.session = session;
103 ctx = sslc;
104 init();
105 }
106
107 public void init() throws SSLException {
108 if (sslEngine != null) {
109 return;
110 }
111
112 InetSocketAddress peer = (InetSocketAddress) session
113 .getAttribute(SslFilter.PEER_ADDRESS);
114 if (peer == null) {
115 sslEngine = ctx.createSSLEngine();
116 } else {
117 sslEngine = ctx.createSSLEngine(peer.getHostName(), peer.getPort());
118 }
119 sslEngine.setUseClientMode(parent.isUseClientMode());
120
121 if (parent.isWantClientAuth()) {
122 sslEngine.setWantClientAuth(true);
123 }
124
125 if (parent.isNeedClientAuth()) {
126 sslEngine.setNeedClientAuth(true);
127 }
128
129 if (parent.getEnabledCipherSuites() != null) {
130 sslEngine.setEnabledCipherSuites(parent.getEnabledCipherSuites());
131 }
132
133 if (parent.getEnabledProtocols() != null) {
134 sslEngine.setEnabledProtocols(parent.getEnabledProtocols());
135 }
136
137 sslEngine.beginHandshake();
138 handshakeStatus = sslEngine.getHandshakeStatus();
139
140 handshakeComplete = false;
141 initialHandshakeComplete = false;
142 writingEncryptedData = false;
143 }
144
145
146
147
148 public void destroy() {
149 if (sslEngine == null) {
150 return;
151 }
152
153
154 try {
155 sslEngine.closeInbound();
156 } catch (SSLException e) {
157 logger.debug(
158 "Unexpected exception from SSLEngine.closeInbound().", e);
159 }
160
161
162 if (outNetBuffer != null) {
163 outNetBuffer.capacity(sslEngine.getSession().getPacketBufferSize());
164 } else {
165 createOutNetBuffer(0);
166 }
167 try {
168 do {
169 outNetBuffer.clear();
170 } while (sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf()).bytesProduced() > 0);
171 } catch (SSLException e) {
172
173 } finally {
174 destroyOutNetBuffer();
175 }
176
177 sslEngine.closeOutbound();
178 sslEngine = null;
179
180 preHandshakeEventQueue.clear();
181 }
182
183 private void destroyOutNetBuffer() {
184 outNetBuffer.free();
185 outNetBuffer = null;
186 }
187
188 public SslFilter getParent() {
189 return parent;
190 }
191
192 public IoSession getSession() {
193 return session;
194 }
195
196
197
198
199 public boolean isWritingEncryptedData() {
200 return writingEncryptedData;
201 }
202
203
204
205
206 public boolean isHandshakeComplete() {
207 return handshakeComplete;
208 }
209
210 public boolean isInboundDone() {
211 return sslEngine == null || sslEngine.isInboundDone();
212 }
213
214 public boolean isOutboundDone() {
215 return sslEngine == null || sslEngine.isOutboundDone();
216 }
217
218
219
220
221 public boolean needToCompleteHandshake() {
222 return handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP && !isInboundDone();
223 }
224
225 public void schedulePreHandshakeWriteRequest(NextFilter nextFilter,
226 WriteRequest writeRequest) {
227 preHandshakeEventQueue.add(new IoFilterEvent(nextFilter,
228 IoEventType.WRITE, session, writeRequest));
229 }
230
231 public void flushPreHandshakeEvents() throws SSLException {
232 IoFilterEvent scheduledWrite;
233
234 while ((scheduledWrite = preHandshakeEventQueue.poll()) != null) {
235 parent.filterWrite(scheduledWrite.getNextFilter(), session,
236 (WriteRequest) scheduledWrite.getParameter());
237 }
238 }
239
240 public void scheduleFilterWrite(NextFilter nextFilter, WriteRequest writeRequest) {
241 filterWriteEventQueue.add(new IoFilterEvent(nextFilter, IoEventType.WRITE, session, writeRequest));
242 }
243
244 public void scheduleMessageReceived(NextFilter nextFilter, Object message) {
245 messageReceivedEventQueue.add(new IoFilterEvent(nextFilter, IoEventType.MESSAGE_RECEIVED, session, message));
246 }
247
248 public void flushScheduledEvents() {
249
250 if (Thread.holdsLock(this)) {
251 return;
252 }
253
254 IoFilterEvent e;
255
256
257
258 synchronized (this) {
259 while ((e = filterWriteEventQueue.poll()) != null) {
260 e.getNextFilter().filterWrite(session, (WriteRequest) e.getParameter());
261 }
262 }
263
264 while ((e = messageReceivedEventQueue.poll()) != null) {
265 e.getNextFilter().messageReceived(session, e.getParameter());
266 }
267 }
268
269
270
271
272
273
274
275
276
277
278 public void messageReceived(NextFilter nextFilter, ByteBuffer buf) throws SSLException {
279
280 if (inNetBuffer == null) {
281 inNetBuffer = IoBuffer.allocate(buf.remaining()).setAutoExpand(true);
282 }
283
284 inNetBuffer.put(buf);
285 if (!handshakeComplete) {
286 handshake(nextFilter);
287 } else {
288 decrypt(nextFilter);
289 }
290
291 if (isInboundDone()) {
292
293 int inNetBufferPosition = inNetBuffer == null? 0 : inNetBuffer.position();
294 buf.position(buf.position() - inNetBufferPosition);
295 inNetBuffer = null;
296 }
297 }
298
299
300
301
302
303
304 public IoBuffer fetchAppBuffer() {
305 IoBuffer appBuffer = this.appBuffer.flip();
306 this.appBuffer = null;
307 return appBuffer;
308 }
309
310
311
312
313
314
315 public IoBuffer fetchOutNetBuffer() {
316 IoBuffer answer = outNetBuffer;
317 if (answer == null) {
318 return emptyBuffer;
319 }
320
321 outNetBuffer = null;
322 return answer.shrink();
323 }
324
325
326
327
328
329
330
331 public void encrypt(ByteBuffer src) throws SSLException {
332 if (!handshakeComplete) {
333 throw new IllegalStateException();
334 }
335
336 if (!src.hasRemaining()) {
337 if (outNetBuffer == null) {
338 outNetBuffer = emptyBuffer;
339 }
340 return;
341 }
342
343 createOutNetBuffer(src.remaining());
344
345
346 while (src.hasRemaining()) {
347
348 SSLEngineResult result = sslEngine.wrap(src, outNetBuffer.buf());
349 if (result.getStatus() == SSLEngineResult.Status.OK) {
350 if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) {
351 doTasks();
352 }
353 } else if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
354 outNetBuffer.capacity(outNetBuffer.capacity() << 1);
355 outNetBuffer.limit(outNetBuffer.capacity());
356 } else {
357 throw new SSLException("SSLEngine error during encrypt: "
358 + result.getStatus() + " src: " + src
359 + "outNetBuffer: " + outNetBuffer);
360 }
361 }
362
363 outNetBuffer.flip();
364 }
365
366
367
368
369
370
371
372
373 public boolean closeOutbound() throws SSLException {
374 if (sslEngine == null || sslEngine.isOutboundDone()) {
375 return false;
376 }
377
378 sslEngine.closeOutbound();
379
380 createOutNetBuffer(0);
381 SSLEngineResult result;
382 for (;;) {
383 result = sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf());
384 if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
385 outNetBuffer.capacity(outNetBuffer.capacity() << 1);
386 outNetBuffer.limit(outNetBuffer.capacity());
387 } else {
388 break;
389 }
390 }
391
392 if (result.getStatus() != SSLEngineResult.Status.CLOSED) {
393 throw new SSLException("Improper close state: " + result);
394 }
395 outNetBuffer.flip();
396 return true;
397 }
398
399
400
401
402
403
404 private void decrypt(NextFilter nextFilter) throws SSLException {
405
406 if (!handshakeComplete) {
407 throw new IllegalStateException();
408 }
409
410 unwrap(nextFilter);
411 }
412
413
414
415
416
417 private void checkStatus(SSLEngineResult res)
418 throws SSLException {
419
420 SSLEngineResult.Status status = res.getStatus();
421
422
423
424
425
426
427
428
429
430 if (status != SSLEngineResult.Status.OK
431 && status != SSLEngineResult.Status.CLOSED
432 && status != SSLEngineResult.Status.BUFFER_UNDERFLOW) {
433 throw new SSLException("SSLEngine error during decrypt: " + status
434 + " inNetBuffer: " + inNetBuffer + "appBuffer: "
435 + appBuffer);
436 }
437 }
438
439
440
441
442 public void handshake(NextFilter nextFilter) throws SSLException {
443 for (; ;) {
444 if (handshakeStatus == SSLEngineResult.HandshakeStatus.FINISHED) {
445 session.setAttribute(
446 SslFilter.SSL_SESSION, sslEngine.getSession());
447 handshakeComplete = true;
448 if (!initialHandshakeComplete
449 && session.containsAttribute(SslFilter.USE_NOTIFICATION)) {
450
451
452 initialHandshakeComplete = true;
453 scheduleMessageReceived(nextFilter,
454 SslFilter.SESSION_SECURED);
455 }
456 break;
457 } else if (handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_TASK) {
458 handshakeStatus = doTasks();
459 } else if (handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
460
461 SSLEngineResult.Status status = unwrapHandshake(nextFilter);
462 if (status == SSLEngineResult.Status.BUFFER_UNDERFLOW &&
463 handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED ||
464 isInboundDone()) {
465
466 break;
467 }
468 } else if (handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP) {
469
470
471 if (outNetBuffer != null && outNetBuffer.hasRemaining()) {
472 break;
473 }
474
475 SSLEngineResult result;
476 createOutNetBuffer(0);
477 for (;;) {
478 result = sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf());
479 if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
480 outNetBuffer.capacity(outNetBuffer.capacity() << 1);
481 outNetBuffer.limit(outNetBuffer.capacity());
482 } else {
483 break;
484 }
485 }
486
487 outNetBuffer.flip();
488 handshakeStatus = result.getHandshakeStatus();
489 writeNetBuffer(nextFilter);
490 } else {
491 throw new IllegalStateException("Invalid Handshaking State"
492 + handshakeStatus);
493 }
494 }
495 }
496
497 private void createOutNetBuffer(int expectedRemaining) {
498
499
500 int capacity = Math.max(
501 expectedRemaining,
502 sslEngine.getSession().getPacketBufferSize());
503
504 if (outNetBuffer != null) {
505 outNetBuffer.capacity(capacity);
506 } else {
507 outNetBuffer = IoBuffer.allocate(capacity).minimumCapacity(0);
508 }
509 }
510
511 public WriteFuture writeNetBuffer(NextFilter nextFilter)
512 throws SSLException {
513
514 if (outNetBuffer == null || !outNetBuffer.hasRemaining()) {
515
516 return null;
517 }
518
519
520
521 writingEncryptedData = true;
522
523
524 WriteFuture writeFuture = null;
525
526 try {
527 IoBuffer writeBuffer = fetchOutNetBuffer();
528 writeFuture = new DefaultWriteFuture(session);
529 parent.filterWrite(nextFilter, session, new DefaultWriteRequest(
530 writeBuffer, writeFuture));
531
532
533 while (needToCompleteHandshake()) {
534 try {
535 handshake(nextFilter);
536 } catch (SSLException ssle) {
537 SSLException newSsle = new SSLHandshakeException(
538 "SSL handshake failed.");
539 newSsle.initCause(ssle);
540 throw newSsle;
541 }
542
543 IoBuffer outNetBuffer = fetchOutNetBuffer();
544 if (outNetBuffer != null && outNetBuffer.hasRemaining()) {
545 writeFuture = new DefaultWriteFuture(session);
546 parent.filterWrite(nextFilter, session,
547 new DefaultWriteRequest(outNetBuffer, writeFuture));
548 }
549 }
550 } finally {
551 writingEncryptedData = false;
552 }
553
554 return writeFuture;
555 }
556
557 private void unwrap(NextFilter nextFilter) throws SSLException {
558
559 if (inNetBuffer != null) {
560 inNetBuffer.flip();
561 }
562
563 if (inNetBuffer == null || !inNetBuffer.hasRemaining()) {
564 return;
565 }
566
567 SSLEngineResult res = unwrap0();
568
569
570 if (inNetBuffer.hasRemaining()) {
571 inNetBuffer.compact();
572 } else {
573 inNetBuffer = null;
574 }
575
576 checkStatus(res);
577
578 renegotiateIfNeeded(nextFilter, res);
579 }
580
581 private SSLEngineResult.Status unwrapHandshake(NextFilter nextFilter) throws SSLException {
582
583 if (inNetBuffer != null) {
584 inNetBuffer.flip();
585 }
586
587 if (inNetBuffer == null || !inNetBuffer.hasRemaining()) {
588
589 return SSLEngineResult.Status.BUFFER_UNDERFLOW;
590 }
591
592 SSLEngineResult res = unwrap0();
593 handshakeStatus = res.getHandshakeStatus();
594
595 checkStatus(res);
596
597
598
599 if (handshakeStatus == SSLEngineResult.HandshakeStatus.FINISHED
600 && res.getStatus() == SSLEngineResult.Status.OK
601 && inNetBuffer.hasRemaining()) {
602 res = unwrap0();
603
604
605 if (inNetBuffer.hasRemaining()) {
606 inNetBuffer.compact();
607 } else {
608 inNetBuffer = null;
609 }
610
611 renegotiateIfNeeded(nextFilter, res);
612 } else {
613
614 if (inNetBuffer.hasRemaining()) {
615 inNetBuffer.compact();
616 } else {
617 inNetBuffer = null;
618 }
619 }
620
621 return res.getStatus();
622 }
623
624 private void renegotiateIfNeeded(NextFilter nextFilter, SSLEngineResult res)
625 throws SSLException {
626 if (res.getStatus() != SSLEngineResult.Status.CLOSED
627 && res.getStatus() != SSLEngineResult.Status.BUFFER_UNDERFLOW
628 && res.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
629
630 handshakeComplete = false;
631 handshakeStatus = res.getHandshakeStatus();
632 handshake(nextFilter);
633 }
634 }
635
636 private SSLEngineResult unwrap0() throws SSLException {
637 if (appBuffer == null) {
638 appBuffer = IoBuffer.allocate(inNetBuffer.remaining());
639 } else {
640 appBuffer.expand(inNetBuffer.remaining());
641 }
642
643 SSLEngineResult res;
644 do {
645 res = sslEngine.unwrap(inNetBuffer.buf(), appBuffer.buf());
646 if (res.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
647 appBuffer.capacity(appBuffer.capacity() << 1);
648 appBuffer.limit(appBuffer.capacity());
649 continue;
650 }
651 } while ((res.getStatus() == SSLEngineResult.Status.OK || res.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) &&
652 (handshakeComplete && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING ||
653 res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP));
654
655 return res;
656 }
657
658
659
660
661 private SSLEngineResult.HandshakeStatus doTasks() {
662
663
664
665
666 Runnable runnable;
667 while ((runnable = sslEngine.getDelegatedTask()) != null) {
668 runnable.run();
669 }
670 return sslEngine.getHandshakeStatus();
671 }
672
673
674
675
676
677
678
679
680 public static IoBuffer copy(ByteBuffer src) {
681 IoBuffer copy = IoBuffer.allocate(src.remaining());
682 copy.put(src);
683 copy.flip();
684 return copy;
685 }
686 }