1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 package org.apache.mina.filter.codec.demux;
21
22 import java.util.Map;
23 import java.util.Set;
24 import java.util.concurrent.ConcurrentHashMap;
25
26 import org.apache.mina.core.session.AttributeKey;
27 import org.apache.mina.core.session.IoSession;
28 import org.apache.mina.core.session.UnknownMessageTypeException;
29 import org.apache.mina.filter.codec.ProtocolEncoder;
30 import org.apache.mina.filter.codec.ProtocolEncoderOutput;
31 import org.apache.mina.util.CopyOnWriteMap;
32 import org.apache.mina.util.IdentityHashSet;
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48 public class DemuxingProtocolEncoder implements ProtocolEncoder {
49
50 private static final AttributeKeyn/AttributeKey.html#AttributeKey">AttributeKey STATE = new AttributeKey(DemuxingProtocolEncoder.class, "state");
51
52 @SuppressWarnings("rawtypes")
53 private final Map<Class<?>, MessageEncoderFactory> type2encoderFactory = new CopyOnWriteMap<>();
54
55 private static final Class<?>[] EMPTY_PARAMS = new Class[0];
56
57
58
59
60
61
62
63 @SuppressWarnings({ "rawtypes", "unchecked" })
64 public void addMessageEncoder(Class<?> messageType, Class<? extends MessageEncoder> encoderClass) {
65 if (encoderClass == null) {
66 throw new IllegalArgumentException("encoderClass");
67 }
68
69 try {
70 encoderClass.getConstructor(EMPTY_PARAMS);
71 } catch (NoSuchMethodException e) {
72 throw new IllegalArgumentException("The specified class doesn't have a public default constructor.");
73 }
74
75 boolean registered = false;
76
77 if (MessageEncoder.class.isAssignableFrom(encoderClass)) {
78 addMessageEncoder(messageType, new DefaultConstructorMessageEncoderFactory(encoderClass));
79 registered = true;
80 }
81
82 if (!registered) {
83 throw new IllegalArgumentException("Unregisterable type: " + encoderClass);
84 }
85 }
86
87
88
89
90
91
92
93
94 @SuppressWarnings({ "unchecked", "rawtypes" })
95 public <T> void addMessageEncoder(Class<T> messageType, MessageEncoder<? super T> encoder) {
96 addMessageEncoder(messageType, new SingletonMessageEncoderFactory(encoder));
97 }
98
99
100
101
102
103
104
105
106 public <T> void addMessageEncoder(Class<T> messageType, MessageEncoderFactory<? super T> factory) {
107 if (messageType == null) {
108 throw new IllegalArgumentException("messageType");
109 }
110
111 if (factory == null) {
112 throw new IllegalArgumentException("factory");
113 }
114
115 synchronized (type2encoderFactory) {
116 if (type2encoderFactory.containsKey(messageType)) {
117 throw new IllegalStateException("The specified message type (" + messageType.getName()
118 + ") is registered already.");
119 }
120
121 type2encoderFactory.put(messageType, factory);
122 }
123 }
124
125
126
127
128
129
130
131 @SuppressWarnings("rawtypes")
132 public void addMessageEncoder(Iterable<Class<?>> messageTypes, Class<? extends MessageEncoder> encoderClass) {
133 for (Class<?> messageType : messageTypes) {
134 addMessageEncoder(messageType, encoderClass);
135 }
136 }
137
138
139
140
141
142
143
144
145 public <T> void addMessageEncoder(Iterable<Class<? extends T>> messageTypes, MessageEncoder<? super T> encoder) {
146 for (Class<? extends T> messageType : messageTypes) {
147 addMessageEncoder(messageType, encoder);
148 }
149 }
150
151
152
153
154
155
156
157
158 public <T> void addMessageEncoder(Iterable<Class<? extends T>> messageTypes,
159 MessageEncoderFactory<? super T> factory) {
160 for (Class<? extends T> messageType : messageTypes) {
161 addMessageEncoder(messageType, factory);
162 }
163 }
164
165
166
167
168 @Override
169 public void encode(IoSession session, Object message, ProtocolEncoderOutput out) throws Exception {
170 State state = getState(session);
171 MessageEncoder<Object> encoder = findEncoder(state, message.getClass());
172 if (encoder != null) {
173 encoder.encode(session, message, out);
174 } else {
175 throw new UnknownMessageTypeException("No message encoder found for message: " + message);
176 }
177 }
178
179 protected MessageEncoder<Object> findEncoder(State state, Class<?> type) {
180 return findEncoder(state, type, null);
181 }
182
183 @SuppressWarnings("unchecked")
184 private MessageEncoder<Object> findEncoder(State state, Class<?> type, Set<Class<?>> triedClasses) {
185 @SuppressWarnings("rawtypes")
186 MessageEncoder encoder;
187
188 if (triedClasses != null && triedClasses.contains(type)) {
189 return null;
190 }
191
192
193
194
195 encoder = state.findEncoderCache.get(type);
196
197 if (encoder != null) {
198 return encoder;
199 }
200
201
202
203
204 encoder = state.type2encoder.get(type);
205
206 if (encoder == null) {
207
208
209
210
211 if (triedClasses == null) {
212 triedClasses = new IdentityHashSet<>();
213 }
214
215 triedClasses.add(type);
216
217 Class<?>[] interfaces = type.getInterfaces();
218
219 for (Class<?> element : interfaces) {
220 encoder = findEncoder(state, element, triedClasses);
221
222 if (encoder != null) {
223 break;
224 }
225 }
226 }
227
228 if (encoder == null) {
229
230
231
232
233
234 Class<?> superclass = type.getSuperclass();
235
236 if (superclass != null) {
237 encoder = findEncoder(state, superclass);
238 }
239 }
240
241
242
243
244
245
246 if (encoder != null) {
247 state.findEncoderCache.put(type, encoder);
248 MessageEncoder<Object> tmpEncoder = state.findEncoderCache.putIfAbsent(type, encoder);
249
250 if (tmpEncoder != null) {
251 encoder = tmpEncoder;
252 }
253 }
254
255 return encoder;
256 }
257
258
259
260
261 @Override
262 public void dispose(IoSession session) throws Exception {
263 session.removeAttribute(STATE);
264 }
265
266 private State getState(IoSession session) throws Exception {
267 State state = (State) session.getAttribute(STATE);
268 if (state == null) {
269 state = new State();
270 State oldState = (State) session.setAttributeIfAbsent(STATE, state);
271 if (oldState != null) {
272 state = oldState;
273 }
274 }
275 return state;
276 }
277
278 private class State {
279 @SuppressWarnings("rawtypes")
280 private final ConcurrentHashMap<Class<?>, MessageEncoder> findEncoderCache = new ConcurrentHashMap<>();
281
282 @SuppressWarnings("rawtypes")
283 private final Map<Class<?>, MessageEncoder> type2encoder = new ConcurrentHashMap<>();
284
285 @SuppressWarnings("rawtypes")
286 private State() throws Exception {
287 for (Map.Entry<Class<?>, MessageEncoderFactory> e : type2encoderFactory.entrySet()) {
288 type2encoder.put(e.getKey(), e.getValue().getEncoder());
289 }
290 }
291 }
292
293 private static class SingletonMessageEncoderFactory<T> implements MessageEncoderFactory<T> {
294 private final MessageEncoder<T> encoder;
295
296 private SingletonMessageEncoderFactory(MessageEncoder<T> encoder) {
297 if (encoder == null) {
298 throw new IllegalArgumentException("encoder");
299 }
300 this.encoder = encoder;
301 }
302
303
304
305
306 @Override
307 public MessageEncoder<T> getEncoder() {
308 return encoder;
309 }
310 }
311
312 private static class DefaultConstructorMessageEncoderFactory<T> implements MessageEncoderFactory<T> {
313 private final Class<MessageEncoder<T>> encoderClass;
314
315 private DefaultConstructorMessageEncoderFactory(Class<MessageEncoder<T>> encoderClass) {
316 if (encoderClass == null) {
317 throw new IllegalArgumentException("encoderClass");
318 }
319
320 if (!MessageEncoder.class.isAssignableFrom(encoderClass)) {
321 throw new IllegalArgumentException("encoderClass is not assignable to MessageEncoder");
322 }
323 this.encoderClass = encoderClass;
324 }
325
326
327
328
329 @Override
330 public MessageEncoder<T> getEncoder() throws Exception {
331 return encoderClass.newInstance();
332 }
333 }
334 }