001/*
002 *  Licensed to the Apache Software Foundation (ASF) under one
003 *  or more contributor license agreements.  See the NOTICE file
004 *  distributed with this work for additional information
005 *  regarding copyright ownership.  The ASF licenses this file
006 *  to you under the Apache License, Version 2.0 (the
007 *  "License"); you may not use this file except in compliance
008 *  with the License.  You may obtain a copy of the License at
009 *
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *
012 *  Unless required by applicable law or agreed to in writing,
013 *  software distributed under the License is distributed on an
014 *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 *  KIND, either express or implied.  See the License for the
016 *  specific language governing permissions and limitations
017 *  under the License.
018 *
019 */
020package org.apache.mina.filter.logging;
021
022import static org.junit.Assert.assertEquals;
023import static org.junit.Assert.assertNotNull;
024import static org.junit.Assert.assertNull;
025import static org.junit.Assert.fail;
026
027import java.io.IOException;
028import java.net.InetSocketAddress;
029import java.net.SocketAddress;
030import java.util.ArrayList;
031import java.util.Collections;
032import java.util.HashSet;
033import java.util.List;
034import java.util.Set;
035import java.util.concurrent.CountDownLatch;
036import java.util.concurrent.ExecutorService;
037import java.util.concurrent.TimeUnit;
038
039import org.apache.log4j.AppenderSkeleton;
040import org.apache.log4j.Level;
041import org.apache.log4j.spi.LoggingEvent;
042import org.apache.mina.core.buffer.IoBuffer;
043import org.apache.mina.core.filterchain.DefaultIoFilterChainBuilder;
044import org.apache.mina.core.filterchain.IoFilterAdapter;
045import org.apache.mina.core.future.ConnectFuture;
046import org.apache.mina.core.service.IoHandlerAdapter;
047import org.apache.mina.core.session.IdleStatus;
048import org.apache.mina.core.session.IoSession;
049import org.apache.mina.filter.codec.ProtocolCodecFactory;
050import org.apache.mina.filter.codec.ProtocolCodecFilter;
051import org.apache.mina.filter.codec.ProtocolDecoder;
052import org.apache.mina.filter.codec.ProtocolDecoderAdapter;
053import org.apache.mina.filter.codec.ProtocolDecoderOutput;
054import org.apache.mina.filter.codec.ProtocolEncoder;
055import org.apache.mina.filter.codec.ProtocolEncoderAdapter;
056import org.apache.mina.filter.codec.ProtocolEncoderOutput;
057import org.apache.mina.filter.executor.ExecutorFilter;
058import org.apache.mina.filter.statistic.ProfilerTimerFilter;
059import org.apache.mina.transport.socket.nio.NioSocketAcceptor;
060import org.apache.mina.transport.socket.nio.NioSocketConnector;
061import org.junit.After;
062import org.junit.Before;
063import org.junit.Test;
064import org.slf4j.Logger;
065import org.slf4j.LoggerFactory;
066
067/**
068 * Tests {@link MdcInjectionFilter} in various scenarios.
069 *
070 * @author <a href="http://mina.apache.org">Apache MINA Project</a>
071 */
072public class MdcInjectionFilterTest {
073
074    static Logger LOGGER = LoggerFactory.getLogger(MdcInjectionFilterTest.class);
075
076    private static final int TIMEOUT = 5000;
077
078    private final MyAppender appender = new MyAppender();
079
080    private int port;
081
082    private NioSocketAcceptor acceptor;
083
084    private Level previousLevelRootLogger;
085
086    private ExecutorFilter executorFilter1;
087
088    private ExecutorFilter executorFilter2;
089
090    @Before
091    public void setUp() throws Exception {
092        // comment out next line if you want to see normal logging
093        org.apache.log4j.Logger.getRootLogger().removeAllAppenders();
094        previousLevelRootLogger = org.apache.log4j.Logger.getRootLogger().getLevel();
095        org.apache.log4j.Logger.getRootLogger().setLevel(Level.DEBUG);
096        org.apache.log4j.Logger.getRootLogger().addAppender(appender);
097        acceptor = new NioSocketAcceptor();
098    }
099
100    @After
101    public void tearDown() throws Exception {
102        acceptor.dispose(true);
103        org.apache.log4j.Logger.getRootLogger().setLevel(previousLevelRootLogger);
104
105        destroy(executorFilter1);
106        destroy(executorFilter2);
107
108        List<String> after = getThreadNames();
109
110        // give acceptor some time to shut down
111        Thread.sleep(50);
112        after = getThreadNames();
113
114        int count = 0;
115
116        // NOTE: this is *not* intended to be a permanent fix for this test-case.
117        // There used to be no API to block until the ExecutorService of AbstractIoService is terminated.
118        // The API exists now : dispose(true) so we should get rid of this code.
119
120        while (contains(after, "Nio") && count++ < 10) {
121            Thread.sleep(50);
122            after = getThreadNames();
123        }
124
125        while (contains(after, "pool") && count++ < 10) {
126            Thread.sleep(50);
127            after = getThreadNames();
128        }
129
130        // The problem is that we clear the events of the appender here, but it's possible that a thread from
131        // a previous test still generates events during the execution of the next test
132        appender.clear();
133    }
134
135    private void destroy(ExecutorFilter executorFilter) throws InterruptedException {
136        if (executorFilter != null) {
137            ExecutorService executor = (ExecutorService) executorFilter.getExecutor();
138            executor.shutdown();
139            while (!executor.isTerminated()) {
140                //System.out.println("Waiting for termination of " + executorFilter);  
141                executor.awaitTermination(10, TimeUnit.MILLISECONDS);
142            }
143        }
144    }
145
146    @Test
147    public void testSimpleChain() throws IOException, InterruptedException {
148        DefaultIoFilterChainBuilder chain = new DefaultIoFilterChainBuilder();
149        chain.addFirst("mdc-injector", new MdcInjectionFilter());
150        chain.addLast("dummy", new DummyIoFilter());
151        chain.addLast("protocol", new ProtocolCodecFilter(new DummyProtocolCodecFactory()));
152        test(chain);
153    }
154
155    @Test
156    public void testExecutorFilterAtTheEnd() throws IOException, InterruptedException {
157        executorFilter1 = new ExecutorFilter();
158        DefaultIoFilterChainBuilder chain = new DefaultIoFilterChainBuilder();
159        MdcInjectionFilter mdcInjectionFilter = new MdcInjectionFilter();
160        chain.addFirst("mdc-injector1", mdcInjectionFilter);
161        chain.addLast("dummy", new DummyIoFilter());
162        chain.addLast("protocol", new ProtocolCodecFilter(new DummyProtocolCodecFactory()));
163        chain.addLast("executor", executorFilter1);
164        chain.addLast("mdc-injector2", mdcInjectionFilter);
165        test(chain);
166    }
167
168    @Test
169    public void testExecutorFilterAtBeginning() throws IOException, InterruptedException {
170        executorFilter1 = new ExecutorFilter();
171        DefaultIoFilterChainBuilder chain = new DefaultIoFilterChainBuilder();
172        MdcInjectionFilter mdcInjectionFilter = new MdcInjectionFilter();
173        chain.addLast("executor", executorFilter1);
174        chain.addLast("mdc-injector", mdcInjectionFilter);
175        chain.addLast("dummy", new DummyIoFilter());
176        chain.addLast("protocol", new ProtocolCodecFilter(new DummyProtocolCodecFactory()));
177        test(chain);
178    }
179
180    @Test
181    public void testExecutorFilterBeforeProtocol() throws IOException, InterruptedException {
182        executorFilter1 = new ExecutorFilter();
183        DefaultIoFilterChainBuilder chain = new DefaultIoFilterChainBuilder();
184        MdcInjectionFilter mdcInjectionFilter = new MdcInjectionFilter();
185        chain.addLast("executor", executorFilter1);
186        chain.addLast("mdc-injector", mdcInjectionFilter);
187        chain.addLast("dummy", new DummyIoFilter());
188        chain.addLast("protocol", new ProtocolCodecFilter(new DummyProtocolCodecFactory()));
189        test(chain);
190    }
191
192    @Test
193    public void testMultipleFilters() throws IOException, InterruptedException {
194        executorFilter1 = new ExecutorFilter();
195        DefaultIoFilterChainBuilder chain = new DefaultIoFilterChainBuilder();
196        MdcInjectionFilter mdcInjectionFilter = new MdcInjectionFilter();
197        chain.addLast("executor", executorFilter1);
198        chain.addLast("mdc-injector", mdcInjectionFilter);
199        chain.addLast("profiler", new ProfilerTimerFilter());
200        chain.addLast("dummy", new DummyIoFilter());
201        chain.addLast("logger", new LoggingFilter());
202        chain.addLast("protocol", new ProtocolCodecFilter(new DummyProtocolCodecFactory()));
203        test(chain);
204    }
205
206    @Test
207    public void testTwoExecutorFilters() throws IOException, InterruptedException {
208        DefaultIoFilterChainBuilder chain = new DefaultIoFilterChainBuilder();
209        MdcInjectionFilter mdcInjectionFilter = new MdcInjectionFilter();
210        executorFilter1 = new ExecutorFilter();
211        executorFilter2 = new ExecutorFilter();
212        chain.addLast("executorFilter1", executorFilter1);
213        chain.addLast("mdc-injector1", mdcInjectionFilter);
214        chain.addLast("protocol", new ProtocolCodecFilter(new DummyProtocolCodecFactory()));
215        chain.addLast("dummy", new DummyIoFilter());
216        chain.addLast("executorFilter2", executorFilter2);
217        // add the MdcInjectionFilter instance after every ExecutorFilter
218        // it's important to use the same MdcInjectionFilter instance
219        chain.addLast("mdc-injector2", mdcInjectionFilter);
220        test(chain);
221    }
222
223    @Test
224    public void testOnlyRemoteAddress() throws IOException, InterruptedException {
225        DefaultIoFilterChainBuilder chain = new DefaultIoFilterChainBuilder();
226        chain.addFirst("mdc-injector", new MdcInjectionFilter(MdcInjectionFilter.MdcKey.remoteAddress));
227        chain.addLast("dummy", new DummyIoFilter());
228        chain.addLast("protocol", new ProtocolCodecFilter(new DummyProtocolCodecFactory()));
229        SimpleIoHandler simpleIoHandler = new SimpleIoHandler();
230        acceptor.setHandler(simpleIoHandler);
231        acceptor.bind(new InetSocketAddress(0));
232        port = acceptor.getLocalAddress().getPort();
233        acceptor.setFilterChainBuilder(chain);
234        // create some clients
235        NioSocketConnector connector = new NioSocketConnector();
236        connector.setHandler(new IoHandlerAdapter());
237        connectAndWrite(connector, 0);
238        connectAndWrite(connector, 1);
239        // wait until Iohandler has received all events
240        simpleIoHandler.messageSentLatch.await();
241        simpleIoHandler.sessionIdleLatch.await();
242        simpleIoHandler.sessionClosedLatch.await();
243        connector.dispose(true);
244
245        // make a copy to prevent ConcurrentModificationException
246        List<LoggingEvent> events = new ArrayList<LoggingEvent>(appender.events);
247        // verify that all logging events have correct MDC
248        for (LoggingEvent event : events) {
249            if (event.getLoggerName().startsWith("org.apache.mina.core.service.AbstractIoService")) {
250                continue;
251            }
252            for (MdcInjectionFilter.MdcKey mdcKey : MdcInjectionFilter.MdcKey.values()) {
253                String key = mdcKey.name();
254                Object value = event.getMDC(key);
255                if (mdcKey == MdcInjectionFilter.MdcKey.remoteAddress) {
256                    assertNotNull("MDC[remoteAddress] not set for [" + event.getMessage() + "]", value);
257                } else {
258                    assertNull("MDC[" + key + "] set for [" + event.getMessage() + "]", value);
259                }
260            }
261        }
262    }
263
264    private void test(DefaultIoFilterChainBuilder chain) throws IOException, InterruptedException {
265        // configure the server
266        SimpleIoHandler simpleIoHandler = new SimpleIoHandler();
267        acceptor.setHandler(simpleIoHandler);
268        acceptor.bind(new InetSocketAddress(0));
269        port = acceptor.getLocalAddress().getPort();
270        acceptor.setFilterChainBuilder(chain);
271        // create some clients
272        NioSocketConnector connector = new NioSocketConnector();
273        connector.setHandler(new IoHandlerAdapter());
274        SocketAddress remoteAddressClients[] = new SocketAddress[2];
275        remoteAddressClients[0] = connectAndWrite(connector, 0);
276        remoteAddressClients[1] = connectAndWrite(connector, 1);
277        // wait until Iohandler has received all events
278        simpleIoHandler.messageSentLatch.await();
279        simpleIoHandler.sessionIdleLatch.await();
280        simpleIoHandler.sessionClosedLatch.await();
281        connector.dispose(true);
282
283        // make a copy to prevent ConcurrentModificationException
284        List<LoggingEvent> events = new ArrayList<LoggingEvent>(appender.events);
285
286        Set<String> loggersToCheck = new HashSet<String>();
287        loggersToCheck.add(MdcInjectionFilterTest.class.getName());
288        loggersToCheck.add(ProtocolCodecFilter.class.getName());
289        loggersToCheck.add(LoggingFilter.class.getName());
290
291        // verify that all logging events have correct MDC
292        for (LoggingEvent event : events) {
293
294            if (loggersToCheck.contains(event.getLoggerName())) {
295                Object remoteAddress = event.getMDC("remoteAddress");
296                assertNotNull("MDC[remoteAddress] not set for [" + event.getMessage() + "]", remoteAddress);
297                assertNotNull("MDC[remotePort] not set for [" + event.getMessage() + "]", event.getMDC("remotePort"));
298                assertEquals("every event should have MDC[handlerClass]", SimpleIoHandler.class.getName(),
299                        event.getMDC("handlerClass"));
300            }
301        }
302        // assert we have received all expected logging events for each client
303        for (int i = 0; i < remoteAddressClients.length; i++) {
304            SocketAddress remoteAddressClient = remoteAddressClients[i];
305            assertEventExists(events, "sessionCreated", remoteAddressClient, null);
306            assertEventExists(events, "sessionOpened", remoteAddressClient, null);
307            assertEventExists(events, "decode", remoteAddressClient, null);
308            assertEventExists(events, "messageReceived-1", remoteAddressClient, null);
309            assertEventExists(events, "messageReceived-2", remoteAddressClient, "user-" + i);
310            assertEventExists(events, "encode", remoteAddressClient, null);
311            assertEventExists(events, "exceptionCaught", remoteAddressClient, "user-" + i);
312            assertEventExists(events, "messageSent-1", remoteAddressClient, "user-" + i);
313            assertEventExists(events, "messageSent-2", remoteAddressClient, null);
314            assertEventExists(events, "sessionIdle", remoteAddressClient, "user-" + i);
315            assertEventExists(events, "sessionClosed", remoteAddressClient, "user-" + i);
316            assertEventExists(events, "sessionClosed", remoteAddressClient, "user-" + i);
317            assertEventExists(events, "DummyIoFilter.sessionOpened", remoteAddressClient, "user-" + i);
318        }
319    }
320
321    private SocketAddress connectAndWrite(NioSocketConnector connector, int clientNr) {
322        ConnectFuture connectFuture = connector.connect(new InetSocketAddress("localhost", port));
323        connectFuture.awaitUninterruptibly(TIMEOUT);
324        IoBuffer message = IoBuffer.allocate(4).putInt(clientNr).flip();
325        IoSession session = connectFuture.getSession();
326        session.write(message).awaitUninterruptibly(TIMEOUT);
327        return session.getLocalAddress();
328    }
329
330    private void assertEventExists(List<LoggingEvent> events, String message, SocketAddress address, String user) {
331        InetSocketAddress remoteAddress = (InetSocketAddress) address;
332        for (LoggingEvent event : events) {
333            if (event.getMessage().equals(message) && event.getMDC("remoteAddress").equals(remoteAddress.toString())
334                    && event.getMDC("remoteIp").equals(remoteAddress.getAddress().getHostAddress())
335                    && event.getMDC("remotePort").equals(remoteAddress.getPort() + "")) {
336                if (user == null && event.getMDC("user") == null) {
337                    return;
338                }
339                if (user != null && user.equals(event.getMDC("user"))) {
340                    return;
341                }
342                return;
343            }
344        }
345        fail("No LoggingEvent found from [" + remoteAddress + "] with message [" + message + "]");
346    }
347
348    private static class SimpleIoHandler extends IoHandlerAdapter {
349        CountDownLatch sessionIdleLatch = new CountDownLatch(2);
350
351        CountDownLatch sessionClosedLatch = new CountDownLatch(2);
352
353        CountDownLatch messageSentLatch = new CountDownLatch(2);
354
355        /**
356         * Default constructor
357         */
358        public SimpleIoHandler() {
359            super();
360        }
361
362        @Override
363        public void sessionCreated(IoSession session) throws Exception {
364            LOGGER.info("sessionCreated");
365            session.getConfig().setIdleTime(IdleStatus.BOTH_IDLE, 1);
366        }
367
368        @Override
369        public void sessionOpened(IoSession session) throws Exception {
370            LOGGER.info("sessionOpened");
371        }
372
373        @Override
374        public void sessionClosed(IoSession session) throws Exception {
375            LOGGER.info("sessionClosed");
376            sessionClosedLatch.countDown();
377        }
378
379        @Override
380        public void sessionIdle(IoSession session, IdleStatus status) throws Exception {
381            LOGGER.info("sessionIdle");
382            sessionIdleLatch.countDown();
383            session.close(true);
384        }
385
386        @Override
387        public void exceptionCaught(IoSession session, Throwable cause) throws Exception {
388            LOGGER.info("exceptionCaught", cause);
389        }
390
391        @Override
392        public void messageReceived(IoSession session, Object message) throws Exception {
393            LOGGER.info("messageReceived-1");
394            // adding a custom property to the context
395            String user = "user-" + message;
396            MdcInjectionFilter.setProperty(session, "user", user);
397            LOGGER.info("messageReceived-2");
398            session.getService().broadcast(message);
399            throw new RuntimeException("just a test, forcing exceptionCaught");
400        }
401
402        @Override
403        public void messageSent(IoSession session, Object message) throws Exception {
404            LOGGER.info("messageSent-1");
405            MdcInjectionFilter.removeProperty(session, "user");
406            LOGGER.info("messageSent-2");
407            messageSentLatch.countDown();
408        }
409    }
410
411    private static class DummyProtocolCodecFactory implements ProtocolCodecFactory {
412        /**
413         * Default constructor
414         */
415        public DummyProtocolCodecFactory() {
416            super();
417        }
418
419        public ProtocolEncoder getEncoder(IoSession session) throws Exception {
420            return new ProtocolEncoderAdapter() {
421                public void encode(IoSession session, Object message, ProtocolEncoderOutput out) throws Exception {
422                    LOGGER.info("encode");
423                    IoBuffer buffer = IoBuffer.allocate(4).putInt(123).flip();
424                    out.write(buffer);
425                }
426            };
427        }
428
429        public ProtocolDecoder getDecoder(IoSession session) throws Exception {
430            return new ProtocolDecoderAdapter() {
431                public void decode(IoSession session, IoBuffer in, ProtocolDecoderOutput out) throws Exception {
432                    if (in.remaining() >= 4) {
433                        int value = in.getInt();
434                        LOGGER.info("decode");
435                        out.write(value);
436                    }
437                }
438            };
439        }
440    }
441
442    private static class MyAppender extends AppenderSkeleton {
443        List<LoggingEvent> events = Collections.synchronizedList(new ArrayList<LoggingEvent>());
444
445        /**
446         * Default constructor
447         */
448        public MyAppender() {
449            super();
450        }
451
452        public void clear() {
453            events.clear();
454        }
455
456        @Override
457        protected void append(final LoggingEvent loggingEvent) {
458            loggingEvent.getMDCCopy();
459            events.add(loggingEvent);
460        }
461
462        public boolean requiresLayout() {
463            return false;
464        }
465
466        public void close() {
467            // Do nothing
468        }
469    }
470
471    static class DummyIoFilter extends IoFilterAdapter {
472        @Override
473        public void sessionOpened(NextFilter nextFilter, IoSession session) throws Exception {
474            LOGGER.info("DummyIoFilter.sessionOpened");
475            nextFilter.sessionOpened(session);
476        }
477    }
478
479    private List<String> getThreadNames() {
480        List<String> list = new ArrayList<String>();
481        int active = Thread.activeCount();
482        Thread[] threads = new Thread[active];
483        Thread.enumerate(threads);
484        for (Thread thread : threads) {
485            try {
486                String name = thread.getName();
487                list.add(name);
488            } catch (NullPointerException ignore) {
489            }
490        }
491        return list;
492    }
493
494    private boolean contains(List<String> list, String search) {
495        for (String s : list) {
496            if (s.contains(search)) {
497                return true;
498            }
499        }
500        return false;
501    }
502}