1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 package org.apache.mina.filter.ssl;
21
22 import java.net.InetSocketAddress;
23 import java.nio.ByteBuffer;
24 import java.util.Queue;
25 import java.util.concurrent.ConcurrentLinkedQueue;
26
27 import javax.net.ssl.SSLContext;
28 import javax.net.ssl.SSLEngine;
29 import javax.net.ssl.SSLEngineResult;
30 import javax.net.ssl.SSLException;
31 import javax.net.ssl.SSLHandshakeException;
32
33 import org.apache.mina.core.buffer.IoBuffer;
34 import org.apache.mina.core.filterchain.IoFilterEvent;
35 import org.apache.mina.core.filterchain.IoFilter.NextFilter;
36 import org.apache.mina.core.future.DefaultWriteFuture;
37 import org.apache.mina.core.future.WriteFuture;
38 import org.apache.mina.core.session.IoEventType;
39 import org.apache.mina.core.session.IoSession;
40 import org.apache.mina.core.write.DefaultWriteRequest;
41 import org.apache.mina.core.write.WriteRequest;
42 import org.apache.mina.util.CircularQueue;
43 import org.slf4j.Logger;
44 import org.slf4j.LoggerFactory;
45
46
47
48
49
50
51
52
53
54
55
56 class SslHandler {
57
58 private final static Logger LOGGER = LoggerFactory.getLogger(SslHandler.class);
59 private final SslFilter parent;
60 private final SSLContext sslContext;
61 private final IoSession session;
62 private final Queue<IoFilterEvent> preHandshakeEventQueue = new CircularQueue<IoFilterEvent>();
63 private final Queue<IoFilterEvent> filterWriteEventQueue = new ConcurrentLinkedQueue<IoFilterEvent>();
64 private final Queue<IoFilterEvent> messageReceivedEventQueue = new ConcurrentLinkedQueue<IoFilterEvent>();
65 private SSLEngine sslEngine;
66
67
68
69
70 private IoBuffer inNetBuffer;
71
72
73
74
75 private IoBuffer outNetBuffer;
76
77
78
79
80 private IoBuffer appBuffer;
81
82
83
84
85 private final IoBuffer emptyBuffer = IoBuffer.allocate(0);
86
87 private SSLEngineResult.HandshakeStatus handshakeStatus;
88 private boolean initialHandshakeComplete;
89 private boolean handshakeComplete;
90 private boolean writingEncryptedData;
91
92
93
94
95
96
97
98 public SslHandler(SslFilter parent, SSLContext sslContext, IoSession session)
99 throws SSLException {
100 this.parent = parent;
101 this.session = session;
102 this.sslContext = sslContext;
103 init();
104 }
105
106
107
108
109
110
111 public void init() throws SSLException {
112 if (sslEngine != null) {
113
114 return;
115 }
116
117 InetSocketAddress peer = (InetSocketAddress) session
118 .getAttribute(SslFilter.PEER_ADDRESS);
119
120
121 if (peer == null) {
122 sslEngine = sslContext.createSSLEngine();
123 } else {
124 sslEngine = sslContext.createSSLEngine(peer.getHostName(), peer.getPort());
125 }
126
127
128 sslEngine.setUseClientMode(parent.isUseClientMode());
129
130
131 if (parent.isWantClientAuth()) {
132 sslEngine.setWantClientAuth(true);
133 }
134
135 if (parent.isNeedClientAuth()) {
136 sslEngine.setNeedClientAuth(true);
137 }
138
139 if (parent.getEnabledCipherSuites() != null) {
140 sslEngine.setEnabledCipherSuites(parent.getEnabledCipherSuites());
141 }
142
143 if (parent.getEnabledProtocols() != null) {
144 sslEngine.setEnabledProtocols(parent.getEnabledProtocols());
145 }
146
147
148 sslEngine.beginHandshake();
149
150
151 handshakeStatus = sslEngine.getHandshakeStatus();
152
153 handshakeComplete = false;
154 initialHandshakeComplete = false;
155 writingEncryptedData = false;
156 }
157
158
159
160
161 public void destroy() {
162 if (sslEngine == null) {
163 return;
164 }
165
166
167 try {
168 sslEngine.closeInbound();
169 } catch (SSLException e) {
170 LOGGER.debug(
171 "Unexpected exception from SSLEngine.closeInbound().", e);
172 }
173
174
175 if (outNetBuffer != null) {
176 outNetBuffer.capacity(sslEngine.getSession().getPacketBufferSize());
177 } else {
178 createOutNetBuffer(0);
179 }
180 try {
181 do {
182 outNetBuffer.clear();
183 } while (sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf()).bytesProduced() > 0);
184 } catch (SSLException e) {
185
186 } finally {
187 destroyOutNetBuffer();
188 }
189
190 sslEngine.closeOutbound();
191 sslEngine = null;
192
193 preHandshakeEventQueue.clear();
194 }
195
196 private void destroyOutNetBuffer() {
197 outNetBuffer.free();
198 outNetBuffer = null;
199 }
200
201 public SslFilter getParent() {
202 return parent;
203 }
204
205 public IoSession getSession() {
206 return session;
207 }
208
209
210
211
212 public boolean isWritingEncryptedData() {
213 return writingEncryptedData;
214 }
215
216
217
218
219 public boolean isHandshakeComplete() {
220 return handshakeComplete;
221 }
222
223 public boolean isInboundDone() {
224 return sslEngine == null || sslEngine.isInboundDone();
225 }
226
227 public boolean isOutboundDone() {
228 return sslEngine == null || sslEngine.isOutboundDone();
229 }
230
231
232
233
234 public boolean needToCompleteHandshake() {
235 return handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP && !isInboundDone();
236 }
237
238 public void schedulePreHandshakeWriteRequest(NextFilter nextFilter,
239 WriteRequest writeRequest) {
240 preHandshakeEventQueue.add(new IoFilterEvent(nextFilter,
241 IoEventType.WRITE, session, writeRequest));
242 }
243
244 public void flushPreHandshakeEvents() throws SSLException {
245 IoFilterEvent scheduledWrite;
246
247 while ((scheduledWrite = preHandshakeEventQueue.poll()) != null) {
248 parent.filterWrite(scheduledWrite.getNextFilter(), session,
249 (WriteRequest) scheduledWrite.getParameter());
250 }
251 }
252
253 public void scheduleFilterWrite(NextFilter nextFilter, WriteRequest writeRequest) {
254 filterWriteEventQueue.add(new IoFilterEvent(nextFilter, IoEventType.WRITE, session, writeRequest));
255 }
256
257 public void scheduleMessageReceived(NextFilter nextFilter, Object message) {
258 messageReceivedEventQueue.add(new IoFilterEvent(nextFilter, IoEventType.MESSAGE_RECEIVED, session, message));
259 }
260
261 public void flushScheduledEvents() {
262
263 if (Thread.holdsLock(this)) {
264 return;
265 }
266
267 IoFilterEvent e;
268
269
270
271 synchronized (this) {
272 while ((e = filterWriteEventQueue.poll()) != null) {
273 e.getNextFilter().filterWrite(session, (WriteRequest) e.getParameter());
274 }
275 }
276
277 while ((e = messageReceivedEventQueue.poll()) != null) {
278 e.getNextFilter().messageReceived(session, e.getParameter());
279 }
280 }
281
282
283
284
285
286
287
288
289
290
291 public void messageReceived(NextFilter nextFilter, ByteBuffer buf) throws SSLException {
292
293 if (inNetBuffer == null) {
294 inNetBuffer = IoBuffer.allocate(buf.remaining()).setAutoExpand(true);
295 }
296
297 inNetBuffer.put(buf);
298 if (!handshakeComplete) {
299 handshake(nextFilter);
300 } else {
301 decrypt(nextFilter);
302 }
303
304 if (isInboundDone()) {
305
306 int inNetBufferPosition = inNetBuffer == null? 0 : inNetBuffer.position();
307 buf.position(buf.position() - inNetBufferPosition);
308 inNetBuffer = null;
309 }
310 }
311
312
313
314
315
316
317 public IoBuffer fetchAppBuffer() {
318 IoBuffer appBuffer = this.appBuffer.flip();
319 this.appBuffer = null;
320 return appBuffer;
321 }
322
323
324
325
326
327
328 public IoBuffer fetchOutNetBuffer() {
329 IoBuffer answer = outNetBuffer;
330 if (answer == null) {
331 return emptyBuffer;
332 }
333
334 outNetBuffer = null;
335 return answer.shrink();
336 }
337
338
339
340
341
342
343
344 public void encrypt(ByteBuffer src) throws SSLException {
345 if (!handshakeComplete) {
346 throw new IllegalStateException();
347 }
348
349 if (!src.hasRemaining()) {
350 if (outNetBuffer == null) {
351 outNetBuffer = emptyBuffer;
352 }
353 return;
354 }
355
356 createOutNetBuffer(src.remaining());
357
358
359 while (src.hasRemaining()) {
360
361 SSLEngineResult result = sslEngine.wrap(src, outNetBuffer.buf());
362 if (result.getStatus() == SSLEngineResult.Status.OK) {
363 if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) {
364 doTasks();
365 }
366 } else if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
367 outNetBuffer.capacity(outNetBuffer.capacity() << 1);
368 outNetBuffer.limit(outNetBuffer.capacity());
369 } else {
370 throw new SSLException("SSLEngine error during encrypt: "
371 + result.getStatus() + " src: " + src
372 + "outNetBuffer: " + outNetBuffer);
373 }
374 }
375
376 outNetBuffer.flip();
377 }
378
379
380
381
382
383
384
385
386 public boolean closeOutbound() throws SSLException {
387 if (sslEngine == null || sslEngine.isOutboundDone()) {
388 return false;
389 }
390
391 sslEngine.closeOutbound();
392
393 createOutNetBuffer(0);
394 SSLEngineResult result;
395 for (;;) {
396 result = sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf());
397 if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
398 outNetBuffer.capacity(outNetBuffer.capacity() << 1);
399 outNetBuffer.limit(outNetBuffer.capacity());
400 } else {
401 break;
402 }
403 }
404
405 if (result.getStatus() != SSLEngineResult.Status.CLOSED) {
406 throw new SSLException("Improper close state: " + result);
407 }
408 outNetBuffer.flip();
409 return true;
410 }
411
412
413
414
415
416
417 private void decrypt(NextFilter nextFilter) throws SSLException {
418
419 if (!handshakeComplete) {
420 throw new IllegalStateException();
421 }
422
423 unwrap(nextFilter);
424 }
425
426
427
428
429
430 private void checkStatus(SSLEngineResult res)
431 throws SSLException {
432
433 SSLEngineResult.Status status = res.getStatus();
434
435
436
437
438
439
440
441
442
443 if (status != SSLEngineResult.Status.OK
444 && status != SSLEngineResult.Status.CLOSED
445 && status != SSLEngineResult.Status.BUFFER_UNDERFLOW) {
446 throw new SSLException("SSLEngine error during decrypt: " + status
447 + " inNetBuffer: " + inNetBuffer + "appBuffer: "
448 + appBuffer);
449 }
450 }
451
452
453
454
455 public void handshake(NextFilter nextFilter) throws SSLException {
456 for (;;) {
457 switch (handshakeStatus) {
458 case FINISHED :
459 session.setAttribute(
460 SslFilter.SSL_SESSION, sslEngine.getSession());
461 handshakeComplete = true;
462
463 if (!initialHandshakeComplete
464 && session.containsAttribute(SslFilter.USE_NOTIFICATION)) {
465
466
467 initialHandshakeComplete = true;
468 scheduleMessageReceived(nextFilter,
469 SslFilter.SESSION_SECURED);
470 }
471
472 return;
473
474 case NEED_TASK :
475 handshakeStatus = doTasks();
476 break;
477
478 case NEED_UNWRAP :
479
480 SSLEngineResult.Status status = unwrapHandshake(nextFilter);
481
482 if (status == SSLEngineResult.Status.BUFFER_UNDERFLOW &&
483 handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED ||
484 isInboundDone()) {
485
486 return;
487 }
488
489 break;
490
491 case NEED_WRAP :
492
493
494 if (outNetBuffer != null && outNetBuffer.hasRemaining()) {
495 return;
496 }
497
498 SSLEngineResult result;
499 createOutNetBuffer(0);
500
501 for (;;) {
502 result = sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf());
503 if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
504 outNetBuffer.capacity(outNetBuffer.capacity() << 1);
505 outNetBuffer.limit(outNetBuffer.capacity());
506 } else {
507 break;
508 }
509 }
510
511 outNetBuffer.flip();
512 handshakeStatus = result.getHandshakeStatus();
513 writeNetBuffer(nextFilter);
514 break;
515
516 default :
517 throw new IllegalStateException("Invalid Handshaking State"
518 + handshakeStatus);
519 }
520 }
521 }
522
523 private void createOutNetBuffer(int expectedRemaining) {
524
525
526 int capacity = Math.max(
527 expectedRemaining,
528 sslEngine.getSession().getPacketBufferSize());
529
530 if (outNetBuffer != null) {
531 outNetBuffer.capacity(capacity);
532 } else {
533 outNetBuffer = IoBuffer.allocate(capacity).minimumCapacity(0);
534 }
535 }
536
537 public WriteFuture writeNetBuffer(NextFilter nextFilter)
538 throws SSLException {
539
540 if (outNetBuffer == null || !outNetBuffer.hasRemaining()) {
541
542 return null;
543 }
544
545
546
547 writingEncryptedData = true;
548
549
550 WriteFuture writeFuture = null;
551
552 try {
553 IoBuffer writeBuffer = fetchOutNetBuffer();
554 writeFuture = new DefaultWriteFuture(session);
555 parent.filterWrite(nextFilter, session, new DefaultWriteRequest(
556 writeBuffer, writeFuture));
557
558
559 while (needToCompleteHandshake()) {
560 try {
561 handshake(nextFilter);
562 } catch (SSLException ssle) {
563 SSLException newSsle = new SSLHandshakeException(
564 "SSL handshake failed.");
565 newSsle.initCause(ssle);
566 throw newSsle;
567 }
568
569 IoBuffer outNetBuffer = fetchOutNetBuffer();
570 if (outNetBuffer != null && outNetBuffer.hasRemaining()) {
571 writeFuture = new DefaultWriteFuture(session);
572 parent.filterWrite(nextFilter, session,
573 new DefaultWriteRequest(outNetBuffer, writeFuture));
574 }
575 }
576 } finally {
577 writingEncryptedData = false;
578 }
579
580 return writeFuture;
581 }
582
583 private void unwrap(NextFilter nextFilter) throws SSLException {
584
585 if (inNetBuffer != null) {
586 inNetBuffer.flip();
587 }
588
589 if (inNetBuffer == null || !inNetBuffer.hasRemaining()) {
590 return;
591 }
592
593 SSLEngineResult res = unwrap0();
594
595
596 if (inNetBuffer.hasRemaining()) {
597 inNetBuffer.compact();
598 } else {
599 inNetBuffer = null;
600 }
601
602 checkStatus(res);
603
604 renegotiateIfNeeded(nextFilter, res);
605 }
606
607 private SSLEngineResult.Status unwrapHandshake(NextFilter nextFilter) throws SSLException {
608
609 if (inNetBuffer != null) {
610 inNetBuffer.flip();
611 }
612
613 if (inNetBuffer == null || !inNetBuffer.hasRemaining()) {
614
615 return SSLEngineResult.Status.BUFFER_UNDERFLOW;
616 }
617
618 SSLEngineResult res = unwrap0();
619 handshakeStatus = res.getHandshakeStatus();
620
621 checkStatus(res);
622
623
624
625 if (handshakeStatus == SSLEngineResult.HandshakeStatus.FINISHED
626 && res.getStatus() == SSLEngineResult.Status.OK
627 && inNetBuffer.hasRemaining()) {
628 res = unwrap0();
629
630
631 if (inNetBuffer.hasRemaining()) {
632 inNetBuffer.compact();
633 } else {
634 inNetBuffer = null;
635 }
636
637 renegotiateIfNeeded(nextFilter, res);
638 } else {
639
640 if (inNetBuffer.hasRemaining()) {
641 inNetBuffer.compact();
642 } else {
643 inNetBuffer = null;
644 }
645 }
646
647 return res.getStatus();
648 }
649
650 private void renegotiateIfNeeded(NextFilter nextFilter, SSLEngineResult res)
651 throws SSLException {
652 if (res.getStatus() != SSLEngineResult.Status.CLOSED
653 && res.getStatus() != SSLEngineResult.Status.BUFFER_UNDERFLOW
654 && res.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
655
656 handshakeComplete = false;
657 handshakeStatus = res.getHandshakeStatus();
658 handshake(nextFilter);
659 }
660 }
661
662 private SSLEngineResult unwrap0() throws SSLException {
663 if (appBuffer == null) {
664 appBuffer = IoBuffer.allocate(inNetBuffer.remaining());
665 } else {
666 appBuffer.expand(inNetBuffer.remaining());
667 }
668
669 SSLEngineResult res;
670 do {
671 res = sslEngine.unwrap(inNetBuffer.buf(), appBuffer.buf());
672 if (res.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
673 appBuffer.capacity(appBuffer.capacity() << 1);
674 appBuffer.limit(appBuffer.capacity());
675 continue;
676 }
677 } while ((res.getStatus() == SSLEngineResult.Status.OK || res.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) &&
678 (handshakeComplete && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING ||
679 res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP));
680
681 return res;
682 }
683
684
685
686
687 private SSLEngineResult.HandshakeStatus doTasks() {
688
689
690
691
692 Runnable runnable;
693 while ((runnable = sslEngine.getDelegatedTask()) != null) {
694
695 runnable.run();
696 }
697 return sslEngine.getHandshakeStatus();
698 }
699
700
701
702
703
704
705
706
707 public static IoBuffer copy(ByteBuffer src) {
708 IoBuffer copy = IoBuffer.allocate(src.remaining());
709 copy.put(src);
710 copy.flip();
711 return copy;
712 }
713 }