1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.logging.log4j.audit.rest;
18
19 import java.io.IOException;
20 import java.text.DecimalFormat;
21 import java.util.Enumeration;
22
23 import javax.servlet.Filter;
24 import javax.servlet.FilterChain;
25 import javax.servlet.FilterConfig;
26 import javax.servlet.ServletException;
27 import javax.servlet.ServletRequest;
28 import javax.servlet.ServletResponse;
29 import javax.servlet.http.HttpServletRequest;
30 import javax.servlet.http.HttpServletResponse;
31
32 import org.apache.logging.log4j.LogManager;
33 import org.apache.logging.log4j.Logger;
34 import org.apache.logging.log4j.ThreadContext;
35 import org.apache.logging.log4j.audit.request.ChainedMapping;
36 import org.apache.logging.log4j.audit.request.RequestContextMapping;
37 import org.apache.logging.log4j.audit.request.RequestContextMappings;
38
39
40
41
42 public class RequestContextFilter implements Filter {
43
44 private static final Logger logger = LogManager.getLogger(RequestContextFilter.class);
45 private final Class<?> requestContextClass;
46 private RequestContextMappings mappings;
47
48 public RequestContextFilter() {
49 requestContextClass = null;
50 }
51
52 public RequestContextFilter(Class<?> clazz) {
53 requestContextClass = clazz;
54 }
55
56
57
58
59
60 @Override
61 public void init(FilterConfig filterConfig) throws ServletException {
62 if (requestContextClass != null) {
63 mappings = new RequestContextMappings(requestContextClass);
64 } else {
65 String requestContextClassName = filterConfig.getInitParameter("requestContextClass");
66 if (requestContextClassName == null) {
67 logger.error("No RequestContext class name was provided");
68 throw new IllegalArgumentException("No RequestContext class name provided");
69 }
70 mappings = new RequestContextMappings(requestContextClassName);
71 }
72 }
73
74
75
76
77 public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
78 throws IOException, ServletException {
79 if (servletRequest instanceof HttpServletRequest) {
80 HttpServletRequest request = (HttpServletRequest) servletRequest;
81 HttpServletResponse response = (HttpServletResponse) servletResponse;
82 logger.info("Starting request {}" + request.getRequestURI());
83 try {
84 Enumeration headers = request.getHeaderNames();
85 while (headers.hasMoreElements()) {
86 String name = (String) headers.nextElement();
87 RequestContextMapping mapping = mappings.getMappingByHeader(name);
88 logger.debug("Got Mapping:{} for Header:{}", mapping, name);
89 if (mapping != null) {
90 if (mapping.isChained()) {
91 ThreadContext.put(mapping.getChainKey(), request.getHeader(name));
92 logger.debug("Setting Context Key:{} with value:{}", mapping.getChainKey(), request.getHeader(name));
93 String value = ((ChainedMapping)mapping).getSupplier().get();
94 ThreadContext.put(mapping.getFieldName(), value);
95 logger.debug("Setting Context Key:{} with value:{}", mapping.getFieldName(), value);
96 } else {
97 ThreadContext.put(mapping.getFieldName(), request.getHeader(name));
98 logger.debug("Setting Context Key:{} with value:{}", mapping.getFieldName(), request.getHeader(name));
99 }
100 }
101 }
102 long start = System.nanoTime();
103 filterChain.doFilter(servletRequest, servletResponse);
104 long elapsed = System.nanoTime() - start;
105 StringBuilder sb = new StringBuilder("Request ").append(request.getRequestURI()).append(" completed in ");
106 ElapsedUtil.addElapsed(elapsed, sb);
107 logger.info(sb.toString());
108 } catch (Throwable e) {
109 logger.error("Application cascaded error", e);
110 response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
111 } finally {
112 ThreadContext.clearMap();
113 }
114 }
115 }
116
117 public void destroy() {
118 }
119 }