/*
 *   @(#) $Id: SSLFilter.java 332218 2005-11-10 03:52:42Z trustin $
 *
 *   Copyright 2004 The Apache Software Foundation
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 *
 */
19  package org.apache.mina.io.filter;
20  
21  import javax.net.ssl.SSLContext;
22  import javax.net.ssl.SSLEngine;
23  import javax.net.ssl.SSLException;
24  import javax.net.ssl.SSLHandshakeException;
25  import javax.net.ssl.SSLSession;
26  
27  import org.apache.mina.common.ByteBuffer;
28  import org.apache.mina.io.IoFilterAdapter;
29  import org.apache.mina.io.IoHandler;
30  import org.apache.mina.io.IoSession;
31  import org.slf4j.Logger;
32  import org.slf4j.LoggerFactory;
33  
34  /***
35   * An SSL filter that encrypts and decrypts the data exchanged in the session.
36   * This filter uses an {@link SSLEngine} which was introduced in Java 5, so 
37   * Java version 5 or above is mandatory to use this filter. And please note that
38   * this filter only works for TCP/IP connections.
39   * <p>
40   * This filter logs debug information using {@link Logger}.
41   * 
42   * <h2>Implementing StartTLS</h2>
43   * <p>
44   * You can use {@link #DISABLE_ENCRYPTION_ONCE} attribute to implement StartTLS:
45   * <pre>
46   * public void messageReceived(ProtocolSession session, Object message) {
47   *    if (message instanceof MyStartTLSRequest) {
48   *        // Insert SSLFilter to get ready for handshaking
49   *        IoSession ioSession = ((IoProtocolSession) session).getIoSession();
50   *        ioSession.getFilterChain().addLast(sslFilter);
51   *
52   *        // Disable encryption temporarilly.
53   *        // This attribute will be removed by SSLFilter
54   *        // inside the Session.write() call below.
55   *        session.setAttribute(SSLFilter.DISABLE_ENCRYPTION_ONCE, Boolean.TRUE);
56   *
57   *        // Write StartTLSResponse which won't be encrypted.
58   *        session.write(new MyStartTLSResponse(OK));
59   *        
60   *        // Now DISABLE_ENCRYPTION_ONCE attribute is cleared.
61   *        assert session.getAttribute(SSLFilter.DISABLE_ENCRYPTION_ONCE) == null;
62   *    }
63   * }
64   * </pre>
65   * 
66   * @author The Apache Directory Project (dev@directory.apache.org)
67   * @version $Rev: 332218 $, $Date: 2005-11-10 12:52:42 +0900 $
68   */
69  public class SSLFilter extends IoFilterAdapter
70  {
71      /***
72       * A session attribute key that stores underlying {@link SSLSession}
73       * for each session.
74       */
75      public static final String SSL_SESSION = SSLFilter.class.getName() + ".SSLSession";
76      
77      /***
78       * A session attribute key that makes next one write request bypass
79       * this filter (not encrypting the data).  This is a marker attribute,
80       * which means that you can put whatever as its value. ({@link Boolean#TRUE}
81       * is preferred.)  The attribute is automatically removed from the session
82       * attribute map as soon as {@link IoSession#write(ByteBuffer, Object)} is
83       * invoked, and therefore should be put again if you want to make more
84       * messages bypass this filter.  This is especially useful when you
85       * implement StartTLS.   
86       */
87      public static final String DISABLE_ENCRYPTION_ONCE = SSLFilter.class.getName() + ".DisableEncryptionOnce";
88      
89      private static final String SSL_HANDLER = SSLFilter.class.getName() + ".SSLHandler";
90  
91      private static final Logger log = LoggerFactory.getLogger( SSLFilter.class );
92  
93      /***
94       * A marker which is passed with {@link IoHandler#dataWritten(IoSession, Object)}
95       * when <tt>SSLFilter</tt> writes data other then user actually requested.
96       */
97      private static final Object SSL_MARKER = new Object()
98      {
99          public String toString()
100         {
101             return "SSL_MARKER";
102         }
103     };
104     
105     // SSL Context
106     private SSLContext sslContext;
107 
108     private boolean client;
109     private boolean needClientAuth;
110     private boolean wantClientAuth;
111     private String[] enabledCipherSuites;
112     private String[] enabledProtocols;
113 
114     /***
115      * Creates a new SSL filter using the specified {@link SSLContext}.
116      */
117     public SSLFilter( SSLContext sslContext )
118     {
119         if( sslContext == null )
120         {
121             throw new NullPointerException( "sslContext" );
122         }
123 
124         this.sslContext = sslContext;
125     }
126     
127     /***
128      * Returns the underlying {@link SSLSession} for the specified session.
129      * 
130      * @return <tt>null</tt> if no {@link SSLSession} is initialized yet.
131      */
132     public SSLSession getSSLSession( IoSession session )
133     {
134         return ( SSLSession ) session.getAttribute( SSL_SESSION );
135     }
136 
137     /***
138      * Returns <tt>true</tt> if the engine is set to use client mode
139      * when handshaking.
140      */
141     public boolean isUseClientMode()
142     {
143         return client;
144     }
145     
146     /***
147      * Configures the engine to use client (or server) mode when handshaking.
148      */
149     public void setUseClientMode( boolean clientMode )
150     {
151         this.client = clientMode;
152     }
153     
154     /***
155      * Returns <tt>true</tt> if the engine will <em>require</em> client authentication.
156      * This option is only useful to engines in the server mode.
157      */
158     public boolean isNeedClientAuth()
159     {
160         return needClientAuth;
161     }
162 
163     /***
164      * Configures the engine to <em>require</em> client authentication.
165      * This option is only useful for engines in the server mode.
166      */
167     public void setNeedClientAuth( boolean needClientAuth )
168     {
169         this.needClientAuth = needClientAuth;
170     }
171     
172     
173     /***
174      * Returns <tt>true</tt> if the engine will <em>request</em> client authentication.
175      * This option is only useful to engines in the server mode.
176      */
177     public boolean isWantClientAuth()
178     {
179         return wantClientAuth;
180     }
181     
182     /***
183      * Configures the engine to <em>request</em> client authentication.
184      * This option is only useful for engines in the server mode.
185      */
186     public void setWantClientAuth( boolean wantClientAuth )
187     {
188         this.wantClientAuth = wantClientAuth;
189     }
190     
191     /***
192      * Returns the list of cipher suites to be enabled when {@link SSLEngine}
193      * is initialized.
194      * 
195      * @return <tt>null</tt> means 'use {@link SSLEngine}'s default.'
196      */
197     public String[] getEnabledCipherSuites()
198     {
199         return enabledCipherSuites;
200     }
201     
202     /***
203      * Sets the list of cipher suites to be enabled when {@link SSLEngine}
204      * is initialized.
205      * 
206      * @param cipherSuites <tt>null</tt> means 'use {@link SSLEngine}'s default.'
207      */
208     public void setEnabledCipherSuites( String[] cipherSuites )
209     {
210         this.enabledCipherSuites = cipherSuites;
211     }
212 
213     /***
214      * Returns the list of protocols to be enabled when {@link SSLEngine}
215      * is initialized.
216      * 
217      * @return <tt>null</tt> means 'use {@link SSLEngine}'s default.'
218      */
219     public String[] getEnabledProtocols()
220     {
221         return enabledProtocols;
222     }
223     
224     /***
225      * Sets the list of protocols to be enabled when {@link SSLEngine}
226      * is initialized.
227      * 
228      * @param protocols <tt>null</tt> means 'use {@link SSLEngine}'s default.'
229      */
230     public void setEnabledProtocols( String[] protocols )
231     {
232         this.enabledProtocols = protocols;
233     }
234 
235     // IoFilter impl.
236 
237     public void sessionOpened( NextFilter nextFilter, IoSession session ) throws SSLException
238     {
239         // Create an SSL handler
240         createSSLSessionHandler( nextFilter, session );
241         nextFilter.sessionOpened( session );
242     }
243 
244     public void sessionClosed( NextFilter nextFilter, IoSession session ) throws SSLException
245     {
246         SSLHandler sslHandler = getSSLSessionHandler( session );
247         if( log.isDebugEnabled() )
248         {
249             log.debug( session + " Closed: " + sslHandler );
250         }
251         if( sslHandler != null )
252         {
253             synchronized( sslHandler )
254             {
255                // Start SSL shutdown process
256                try
257                {
258                   // shut down
259                   sslHandler.shutdown();
260                   
261                   // there might be data to write out here?
262                   writeNetBuffer( nextFilter, session, sslHandler );
263                }
264                finally
265                {
266                   // notify closed session
267                   nextFilter.sessionClosed( session );
268                   
269                   // release buffers
270                   sslHandler.release();
271                   removeSSLSessionHandler( session );
272                }
273             }
274         }
275     }
276    
277     public void dataRead( NextFilter nextFilter, IoSession session,
278                           ByteBuffer buf ) throws SSLException
279     {
280         SSLHandler sslHandler = createSSLSessionHandler( nextFilter, session );
281         if( sslHandler != null )
282         {
283             if( log.isDebugEnabled() )
284             {
285                 log.debug( session + " Data Read: " + sslHandler + " (" + buf+ ')' );
286             }
287             synchronized( sslHandler )
288             {
289                 try
290                 {
291                     // forward read encrypted data to SSL handler
292                     sslHandler.dataRead( nextFilter, buf.buf() );
293 
294                     // Handle data to be forwarded to application or written to net
295                     handleSSLData( nextFilter, session, sslHandler );
296 
297                     if( sslHandler.isClosed() )
298                     {
299                         if( log.isDebugEnabled() )
300                         {
301                             log.debug( session + " SSL Session closed. Closing connection.." );
302                         }
303                         session.close();
304                     }
305                 }
306                 catch( SSLException ssle )
307                 {
308                     if( !sslHandler.isInitialHandshakeComplete() )
309                     {
310                         SSLException newSSLE = new SSLHandshakeException(
311                                 "Initial SSL handshake failed." );
312                         newSSLE.initCause( ssle );
313                         ssle = newSSLE;
314                     }
315 
316                     throw ssle;
317                 }
318             }
319         }
320         else
321         {
322             nextFilter.dataRead( session, buf );
323         }
324     }
325 
326     public void dataWritten( NextFilter nextFilter, IoSession session,
327                             Object marker )
328     {
329         if( marker != SSL_MARKER )
330         {
331             nextFilter.dataWritten( session, marker );
332         }
333     }
334 
335     public void filterWrite( NextFilter nextFilter, IoSession session, ByteBuffer buf, Object marker ) throws SSLException
336     {
337         // Don't encrypt the data if encryption is disabled.
338         if( session.getAttribute(DISABLE_ENCRYPTION_ONCE) != null )
339         {
340             // Remove the marker attribute because it is temporary.
341             session.removeAttribute(DISABLE_ENCRYPTION_ONCE);
342             nextFilter.filterWrite( session, buf, marker );
343             return;
344         }
345         
346         // Otherwise, encrypt the buffer.
347         SSLHandler handler = createSSLSessionHandler( nextFilter, session );
348         if( log.isDebugEnabled() )
349         {
350             log.debug( session + " Filtered Write: " + handler );
351         }
352 
353         synchronized( handler )
354         {
355             if( handler.isWritingEncryptedData() )
356             {
357                 // data already encrypted; simply return buffer
358                 if( log.isDebugEnabled() )
359                 {
360                     log.debug( session + "   already encrypted: " + buf );
361                 }
362                 nextFilter.filterWrite( session, buf, marker );
363                 return;
364             }
365             
366             if( handler.isInitialHandshakeComplete() )
367             {
368                 // SSL encrypt
369                 if( log.isDebugEnabled() )
370                 {
371                     log.debug( session + " encrypt: " + buf );
372                 }
373                 handler.encrypt( buf.buf() );
374                 ByteBuffer encryptedBuffer = copy( handler
375                         .getOutNetBuffer() );
376 
377                 if( log.isDebugEnabled() )
378                 {
379                     log.debug( session + " encrypted buf: " + encryptedBuffer);
380                 }
381                 buf.release();
382                 nextFilter.filterWrite( session, encryptedBuffer, marker );
383                 return;
384             }
385             else
386             {
387                 if( !session.isConnected() )
388                 {
389                     if( log.isDebugEnabled() )
390                     {
391                         log.debug( session + " Write request on closed session." );
392                     }
393                 }
394                 else
395                 {
396                     if( log.isDebugEnabled() )
397                     {
398                         log.debug( session + " Handshaking is not complete yet. Buffering write request." );
399                     }
400                     handler.scheduleWrite( nextFilter, buf, marker );
401                 }
402             }
403         }
404     }
405 
406     // Utiliities
407 
408     private void handleSSLData( NextFilter nextFilter, IoSession session,
409                                SSLHandler handler ) throws SSLException
410     {
411         // Flush any buffered write requests occurred before handshaking.
412         if( handler.isInitialHandshakeComplete() )
413         {
414             handler.flushScheduledWrites();
415         }
416 
417         // Write encrypted data to be written (if any)
418         writeNetBuffer( nextFilter, session, handler );
419 
420         // handle app. data read (if any)
421         handleAppDataRead( nextFilter, session, handler );
422     }
423 
424     private void handleAppDataRead( NextFilter nextFilter, IoSession session,
425                                    SSLHandler sslHandler )
426     {
427         if( log.isDebugEnabled() )
428         {
429             log.debug( session + " appBuffer: " + sslHandler.getAppBuffer() );
430         }
431         if( sslHandler.getAppBuffer().hasRemaining() )
432         {
433             // forward read app data
434             ByteBuffer readBuffer = copy( sslHandler.getAppBuffer() );
435             if( log.isDebugEnabled() )
436             {
437                 log.debug( session + " app data read: " + readBuffer + " (" + readBuffer.getHexDump() + ')' );
438             }
439             nextFilter.dataRead( session, readBuffer );
440         }
441     }
442 
443     void writeNetBuffer( NextFilter nextFilter, IoSession session, SSLHandler sslHandler )
444             throws SSLException
445     {
446         // Check if any net data needed to be writen
447         if( !sslHandler.getOutNetBuffer().hasRemaining() )
448         {
449             // no; bail out
450             return;
451         }
452 
453         // write net data
454 
455         // set flag that we are writing encrypted data
456         // (used in filterWrite() above)
457         synchronized( sslHandler )
458         {
459             sslHandler.setWritingEncryptedData( true );
460         }
461 
462         try
463         {
464             if( log.isDebugEnabled() )
465             {
466                 log.debug( session + " write outNetBuffer: " +
467                                    sslHandler.getOutNetBuffer() );
468             }
469             ByteBuffer writeBuffer = copy( sslHandler.getOutNetBuffer() );
470             if( log.isDebugEnabled() )
471             {
472                 log.debug( session + " session write: " + writeBuffer );
473             }
474             //debug("outNetBuffer (after copy): {0}", sslHandler.getOutNetBuffer());
475             filterWrite( nextFilter, session, writeBuffer, SSL_MARKER );
476 
477             // loop while more writes required to complete handshake
478             while( sslHandler.needToCompleteInitialHandshake() )
479             {
480                 try
481                 {
482                     sslHandler.continueHandshake( nextFilter );
483                 }
484                 catch( SSLException ssle )
485                 {
486                     SSLException newSSLE = new SSLHandshakeException(
487                             "Initial SSL handshake failed." );
488                     newSSLE.initCause( ssle );
489                     throw newSSLE;
490                 }
491                 if( sslHandler.getOutNetBuffer().hasRemaining() )
492                 {
493                     if( log.isDebugEnabled() )
494                     {
495                         log.debug( session + " write outNetBuffer2: " +
496                                            sslHandler.getOutNetBuffer() );
497                     }
498                     ByteBuffer writeBuffer2 = copy( sslHandler
499                             .getOutNetBuffer() );
500                     filterWrite( nextFilter, session, writeBuffer2, SSL_MARKER );
501                 }
502             }
503         }
504         finally
505         {
506             synchronized( sslHandler )
507             {
508                 sslHandler.setWritingEncryptedData( false );
509             }
510         }
511     }
512 
513     /***
514      * Creates a new Mina byte buffer that is a deep copy of the remaining bytes
515      * in the given buffer (between index buf.position() and buf.limit())
516      *
517      * @param src the buffer to copy
518      * @return the new buffer, ready to read from
519      */
520     private static ByteBuffer copy( java.nio.ByteBuffer src )
521     {
522         ByteBuffer copy = ByteBuffer.allocate( src.remaining() );
523         copy.put( src );
524         copy.flip();
525         return copy;
526     }
527 
528     // Utilities to mainpulate SSLHandler based on IoSession
529 
530     private SSLHandler createSSLSessionHandler( NextFilter nextFilter, IoSession session ) throws SSLException
531     {
532         SSLHandler handler = getSSLSessionHandler( session );
533         if( handler == null )
534         {
535             synchronized( session )
536             {
537                 handler = getSSLSessionHandler( session );
538                 if( handler == null )
539                 {
540                     boolean done = false;
541                     try
542                     {
543                         handler =
544                             new SSLHandler( this, sslContext, session );
545                         session.setAttribute( SSL_HANDLER, handler );
546                         handler.doHandshake( nextFilter );
547                         done = true;
548                     }
549                     finally 
550                     {
551                         if( !done )
552                         {
553                             session.removeAttribute( SSL_HANDLER );
554                         }
555                     }
556                 }
557             }
558         }
559         
560         return handler;
561     }
562 
563     private SSLHandler getSSLSessionHandler( IoSession session )
564     {
565         return ( SSLHandler ) session.getAttribute( SSL_HANDLER );
566     }
567 
568     private void removeSSLSessionHandler( IoSession session )
569     {
570         session.removeAttribute( SSL_HANDLER );
571     }
572 }