1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 package org.apache.mina.filter.logging;
21
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.mina.core.filterchain.IoFilterEvent;
import org.apache.mina.core.session.AttributeKey;
import org.apache.mina.core.session.IoSession;
import org.apache.mina.filter.util.CommonEventFilter;
import org.slf4j.MDC;
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73 public class MdcInjectionFilter extends CommonEventFilter {
74
75 public enum MdcKey {
76 handlerClass, remoteAddress, localAddress, remoteIp, remotePort, localIp, localPort
77 }
78
79
80 private static final AttributeKey CONTEXT_KEY = new AttributeKey(MdcInjectionFilter.class, "context");
81
82 private ThreadLocal<Integer> callDepth = new ThreadLocal<Integer>() {
83 @Override
84 protected Integer initialValue() {
85 return 0;
86 }
87 };
88
89 private EnumSet<MdcKey> mdcKeys;
90
91
92
93
94
95
96
97
98 public MdcInjectionFilter(EnumSet<MdcKey> keys) {
99 this.mdcKeys = keys.clone();
100 }
101
102
103
104
105
106
107
108
109 public MdcInjectionFilter(MdcKey... keys) {
110 Set<MdcKey> keySet = new HashSet<MdcKey>(Arrays.asList(keys));
111 this.mdcKeys = EnumSet.copyOf(keySet);
112 }
113
114 public MdcInjectionFilter() {
115 this.mdcKeys = EnumSet.allOf(MdcKey.class);
116 }
117
118 @Override
119 protected void filter(IoFilterEvent event) throws Exception {
120
121
122 int currentCallDepth = callDepth.get();
123 callDepth.set(currentCallDepth + 1);
124 Map<String, String> context = getAndFillContext(event.getSession());
125
126 if (currentCallDepth == 0) {
127
128 for (Map.Entry<String, String> e : context.entrySet()) {
129 MDC.put(e.getKey(), e.getValue());
130 }
131 }
132
133 try {
134
135 event.fire();
136 } finally {
137 if (currentCallDepth == 0) {
138
139 for (String key : context.keySet()) {
140 MDC.remove(key);
141 }
142 callDepth.remove();
143 } else {
144 callDepth.set(currentCallDepth);
145 }
146 }
147 }
148
149 private Map<String, String> getAndFillContext(final IoSession session) {
150 Map<String, String> context = getContext(session);
151 if (context.isEmpty()) {
152 fillContext(session, context);
153 }
154 return context;
155 }
156
157 @SuppressWarnings("unchecked")
158 private static Map<String, String> getContext(final IoSession session) {
159 Map<String, String> context = (Map<String, String>) session.getAttribute(CONTEXT_KEY);
160 if (context == null) {
161 context = new ConcurrentHashMap<String, String>();
162 session.setAttribute(CONTEXT_KEY, context);
163 }
164 return context;
165 }
166
167
168
169
170
171
172
173 protected void fillContext(final IoSession session, final Map<String, String> context) {
174 if (mdcKeys.contains(MdcKey.handlerClass)) {
175 context.put(MdcKey.handlerClass.name(), session.getHandler().getClass().getName());
176 }
177 if (mdcKeys.contains(MdcKey.remoteAddress)) {
178 context.put(MdcKey.remoteAddress.name(), session.getRemoteAddress().toString());
179 }
180 if (mdcKeys.contains(MdcKey.localAddress)) {
181 context.put(MdcKey.localAddress.name(), session.getLocalAddress().toString());
182 }
183 if (session.getTransportMetadata().getAddressType() == InetSocketAddress.class) {
184 InetSocketAddress remoteAddress = (InetSocketAddress) session.getRemoteAddress();
185 InetSocketAddress localAddress = (InetSocketAddress) session.getLocalAddress();
186
187 if (mdcKeys.contains(MdcKey.remoteIp)) {
188 context.put(MdcKey.remoteIp.name(), remoteAddress.getAddress().getHostAddress());
189 }
190 if (mdcKeys.contains(MdcKey.remotePort)) {
191 context.put(MdcKey.remotePort.name(), String.valueOf(remoteAddress.getPort()));
192 }
193 if (mdcKeys.contains(MdcKey.localIp)) {
194 context.put(MdcKey.localIp.name(), localAddress.getAddress().getHostAddress());
195 }
196 if (mdcKeys.contains(MdcKey.localPort)) {
197 context.put(MdcKey.localPort.name(), String.valueOf(localAddress.getPort()));
198 }
199 }
200 }
201
202 public static String getProperty(IoSession session, String key) {
203 if (key == null) {
204 throw new IllegalArgumentException("key should not be null");
205 }
206
207 Map<String, String> context = getContext(session);
208 String answer = context.get(key);
209 if (answer != null) {
210 return answer;
211 }
212
213 return MDC.get(key);
214 }
215
216
217
218
219
220
221
222
223 public static void setProperty(IoSession session, String key, String value) {
224 if (key == null) {
225 throw new IllegalArgumentException("key should not be null");
226 }
227 if (value == null) {
228 removeProperty(session, key);
229 }
230 Map<String, String> context = getContext(session);
231 context.put(key, value);
232 MDC.put(key, value);
233 }
234
235 public static void removeProperty(IoSession session, String key) {
236 if (key == null) {
237 throw new IllegalArgumentException("key should not be null");
238 }
239 Map<String, String> context = getContext(session);
240 context.remove(key);
241 MDC.remove(key);
242 }
243 }