View Javadoc
1   /*
2    * ====================================================================
3    * Licensed to the Apache Software Foundation (ASF) under one
4    * or more contributor license agreements.  See the NOTICE file
5    * distributed with this work for additional information
6    * regarding copyright ownership.  The ASF licenses this file
7    * to you under the Apache License, Version 2.0 (the
8    * "License"); you may not use this file except in compliance
9    * with the License.  You may obtain a copy of the License at
10   *
11   *   http://www.apache.org/licenses/LICENSE-2.0
12   *
13   * Unless required by applicable law or agreed to in writing,
14   * software distributed under the License is distributed on an
15   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16   * KIND, either express or implied.  See the License for the
17   * specific language governing permissions and limitations
18   * under the License.
19   * ====================================================================
20   *
21   * This software consists of voluntary contributions made by many
22   * individuals on behalf of the Apache Software Foundation.  For more
23   * information on the Apache Software Foundation, please see
24   * <http://www.apache.org/>.
25   *
26   */
27  package org.apache.hc.core5.reactor.ssl;
28  
29  
30  import static org.mockito.ArgumentMatchers.any;
31  import static org.mockito.Mockito.mock;
32  
33  import java.nio.ByteBuffer;
34  import java.security.Provider;
35  import java.util.concurrent.locks.ReentrantLock;
36  
37  import javax.net.ssl.SSLContext;
38  import javax.net.ssl.SSLContextSpi;
39  import javax.net.ssl.SSLEngine;
40  import javax.net.ssl.SSLEngineResult;
41  import javax.net.ssl.SSLException;
42  import javax.net.ssl.SSLSession;
43  
44  import org.apache.hc.core5.concurrent.FutureCallback;
45  import org.apache.hc.core5.function.Callback;
46  import org.apache.hc.core5.net.NamedEndpoint;
47  import org.apache.hc.core5.reactor.IOEventHandler;
48  import org.apache.hc.core5.reactor.IOSession;
49  import org.apache.hc.core5.util.Timeout;
50  import org.junit.jupiter.api.Assertions;
51  import org.junit.jupiter.api.BeforeEach;
52  import org.junit.jupiter.api.Test;
53  import org.mockito.Mockito;
54  
55  class SSLIOSessionTest {
56  
57      // Define common variables here, so you can easily modify them in each test
58      private NamedEndpoint targetEndpoint;
59      private IOSession ioSession;
60      private SSLMode sslMode;
61      private SSLContext sslContext;
62      private SSLBufferMode sslBufferMode;
63      private SSLSessionInitializer initializer;
64      private SSLSessionVerifier verifier;
65      private Timeout handshakeTimeout;
66      private Callback<SSLIOSession> sessionStartCallback;
67      private Callback<SSLIOSession> sessionEndCallback;
68      private FutureCallback<SSLSession> resultCallback;
69      private IOEventHandler ioEventHandler;
70      private SSLEngine mockSSLEngine;
71      private SSLSession sslSession;
72  
73      @BeforeEach
74      public void setUp() throws SSLException {
75          final String protocol = "TestProtocol";
76  
77          // Arrange
78          targetEndpoint = mock(NamedEndpoint.class);
79          ioSession = mock(IOSession.class);
80          sslMode = SSLMode.CLIENT;  // Use actual SSLMode
81  
82          //SSLContext sslContext = SSLContext.getDefault();
83          final SSLContextSpi sslContextSpi = mock(SSLContextSpi.class);
84          final Provider provider = mock(Provider.class);
85          sslContext = new TestSSLContext(sslContextSpi, provider, protocol);
86  
87          sslSession = mock(SSLSession.class);
88  
89          sslBufferMode = SSLBufferMode.STATIC;
90          initializer = mock(SSLSessionInitializer.class);
91          verifier = mock(SSLSessionVerifier.class);
92          handshakeTimeout = mock(Timeout.class);
93          sessionStartCallback = mock(Callback.class);
94          sessionEndCallback = mock(Callback.class);
95          resultCallback = mock(FutureCallback.class);
96          ioEventHandler = mock(IOEventHandler.class);
97  
98          // Mock behavior of targetEndpoint
99          Mockito.when(targetEndpoint.getHostName()).thenReturn("testHostName");
100         Mockito.when(targetEndpoint.getPort()).thenReturn(8080);
101 
102         Mockito.when(sslSession.getPacketBufferSize()).thenReturn(1024);
103         Mockito.when(sslSession.getApplicationBufferSize()).thenReturn(1024);
104 
105         // Mock behavior of ioSession
106         Mockito.when(ioSession.getEventMask()).thenReturn(1);
107         Mockito.when(ioSession.getLock()).thenReturn(new ReentrantLock());
108 
109         // Mock behavior of sslContext and SSLEngine
110         mockSSLEngine = mock(SSLEngine.class);
111         Mockito.when(sslContext.createSSLEngine(any(String.class), any(Integer.class))).thenReturn(mockSSLEngine);
112         Mockito.when(mockSSLEngine.getSession()).thenReturn(sslSession);
113         Mockito.when(mockSSLEngine.getHandshakeStatus()).thenReturn(SSLEngineResult.HandshakeStatus.NEED_WRAP);
114 
115         Mockito.when(ioSession.getHandler()).thenReturn(ioEventHandler);
116 
117         final SSLEngineResult mockResult = new SSLEngineResult(SSLEngineResult.Status.CLOSED,
118                 SSLEngineResult.HandshakeStatus.FINISHED,
119                 0, 464);
120         Mockito.when(mockSSLEngine.wrap(any(ByteBuffer.class), any(ByteBuffer.class))).thenReturn(mockResult);
121     }
122 
123 
124     @Test
125     void testConstructorWhenSSLEngineOk() {
126         final String protocol = "TestProtocol";
127         // Arrange
128         Mockito.when(mockSSLEngine.getApplicationProtocol()).thenReturn(protocol);
129 
130         // Act
131         final TestableSSLIOSession sslioSession = new TestableSSLIOSession(targetEndpoint, ioSession, sslMode, sslContext,
132                 sslBufferMode, initializer, verifier, handshakeTimeout, sessionStartCallback, sessionEndCallback,
133                 resultCallback);
134 
135         // Assert
136         Assertions.assertDoesNotThrow(() -> sslioSession.beginHandshake(ioSession));
137         Assertions.assertEquals(protocol, sslioSession.getTlsDetails().getApplicationProtocol());
138     }
139 
140     @Test
141     void testConstructorWhenSSLEngineThrowsException() {
142         final String protocol = "http/1.1";
143         // Arrange
144         Mockito.when(mockSSLEngine.getApplicationProtocol()).thenThrow(UnsupportedOperationException.class);
145 
146         // Act
147         final TestableSSLIOSession sslioSession = new TestableSSLIOSession(targetEndpoint, ioSession, sslMode, sslContext,
148                 sslBufferMode, initializer, verifier, handshakeTimeout, sessionStartCallback, sessionEndCallback,
149                 resultCallback);
150 
151         // Assert
152         Assertions.assertDoesNotThrow(() -> sslioSession.beginHandshake(ioSession));
153         Assertions.assertEquals(protocol, sslioSession.getTlsDetails().getApplicationProtocol());
154     }
155 
156 
157     static class TestSSLContext extends SSLContext {
158 
159         /**
160          * Creates an SSLContext object.
161          *
162          * @param contextSpi the delegate
163          * @param provider   the provider
164          * @param protocol   the protocol
165          */
166         protected TestSSLContext(final SSLContextSpi contextSpi, final Provider provider, final String protocol) {
167             super(contextSpi, provider, protocol);
168         }
169     }
170 
171     static class TestableSSLIOSession extends SSLIOSession {
172         TestableSSLIOSession(final NamedEndpoint targetEndpoint, final IOSession session, final SSLMode sslMode, final SSLContext sslContext,
173                              final SSLBufferMode sslBufferMode, final SSLSessionInitializer initializer, final SSLSessionVerifier verifier,
174                              final Timeout handshakeTimeout, final Callback<SSLIOSession> sessionStartCallback,
175                              final Callback<SSLIOSession> sessionEndCallback, final FutureCallback<SSLSession> resultCallback) {
176             super(targetEndpoint, session, sslMode, sslContext, sslBufferMode, initializer, verifier,
177                     handshakeTimeout, sessionStartCallback, sessionEndCallback, resultCallback);
178         }
179     }
180 
181 }
182 
183