1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 package org.apache.mina.filter.codec.demux;
21
22 import java.util.Map;
23 import java.util.Set;
24 import java.util.concurrent.ConcurrentHashMap;
25
26 import org.apache.mina.core.session.AttributeKey;
27 import org.apache.mina.core.session.IoSession;
28 import org.apache.mina.core.session.UnknownMessageTypeException;
29 import org.apache.mina.filter.codec.ProtocolEncoder;
30 import org.apache.mina.filter.codec.ProtocolEncoderOutput;
31 import org.apache.mina.util.CopyOnWriteMap;
32 import org.apache.mina.util.IdentityHashSet;
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48 public class DemuxingProtocolEncoder implements ProtocolEncoder {
49
50 private final AttributeKey STATE = new AttributeKey(getClass(), "state");
51
52 @SuppressWarnings("unchecked")
53 private final Map<Class<?>, MessageEncoderFactory> type2encoderFactory = new CopyOnWriteMap<Class<?>, MessageEncoderFactory>();
54
55 private static final Class<?>[] EMPTY_PARAMS = new Class[0];
56
57 public DemuxingProtocolEncoder() {
58
59 }
60
61 @SuppressWarnings("unchecked")
62 public void addMessageEncoder(Class<?> messageType, Class<? extends MessageEncoder> encoderClass) {
63 if (encoderClass == null) {
64 throw new IllegalArgumentException("encoderClass");
65 }
66
67 try {
68 encoderClass.getConstructor(EMPTY_PARAMS);
69 } catch (NoSuchMethodException e) {
70 throw new IllegalArgumentException(
71 "The specified class doesn't have a public default constructor.");
72 }
73
74 boolean registered = false;
75 if (MessageEncoder.class.isAssignableFrom(encoderClass)) {
76 addMessageEncoder(messageType, new DefaultConstructorMessageEncoderFactory(encoderClass));
77 registered = true;
78 }
79
80 if (!registered) {
81 throw new IllegalArgumentException(
82 "Unregisterable type: " + encoderClass);
83 }
84 }
85
86 @SuppressWarnings("unchecked")
87 public <T> void addMessageEncoder(Class<T> messageType, MessageEncoder<? super T> encoder) {
88 addMessageEncoder(messageType, new SingletonMessageEncoderFactory(encoder));
89 }
90
91 public <T> void addMessageEncoder(Class<T> messageType, MessageEncoderFactory<? super T> factory) {
92 if (messageType == null) {
93 throw new IllegalArgumentException("messageType");
94 }
95
96 if (factory == null) {
97 throw new IllegalArgumentException("factory");
98 }
99
100 synchronized (type2encoderFactory) {
101 if (type2encoderFactory.containsKey(messageType)) {
102 throw new IllegalStateException(
103 "The specified message type (" + messageType.getName() + ") is registered already.");
104 }
105
106 type2encoderFactory.put(messageType, factory);
107 }
108 }
109
110 @SuppressWarnings("unchecked")
111 public void addMessageEncoder(Iterable<Class<?>> messageTypes, Class<? extends MessageEncoder> encoderClass) {
112 for (Class<?> messageType : messageTypes) {
113 addMessageEncoder(messageType, encoderClass);
114 }
115 }
116
117 public <T> void addMessageEncoder(Iterable<Class<? extends T>> messageTypes, MessageEncoder<? super T> encoder) {
118 for (Class<? extends T> messageType : messageTypes) {
119 addMessageEncoder(messageType, encoder);
120 }
121 }
122
123 public <T> void addMessageEncoder(Iterable<Class<? extends T>> messageTypes, MessageEncoderFactory<? super T> factory) {
124 for (Class<? extends T> messageType : messageTypes) {
125 addMessageEncoder(messageType, factory);
126 }
127 }
128
129
130
131
132 public void encode(IoSession session, Object message,
133 ProtocolEncoderOutput out) throws Exception {
134 State state = getState(session);
135 MessageEncoder<Object> encoder = findEncoder(state, message.getClass());
136 if (encoder != null) {
137 encoder.encode(session, message, out);
138 } else {
139 throw new UnknownMessageTypeException(
140 "No message encoder found for message: " + message);
141 }
142 }
143
144 protected MessageEncoder<Object> findEncoder(State state, Class<?> type) {
145 return findEncoder(state, type, null);
146 }
147
148 @SuppressWarnings("unchecked")
149 private MessageEncoder<Object> findEncoder(
150 State state, Class type, Set<Class> triedClasses) {
151 MessageEncoder encoder = null;
152
153 if (triedClasses != null && triedClasses.contains(type)) {
154 return null;
155 }
156
157
158
159
160 encoder = state.findEncoderCache.get(type);
161 if (encoder != null) {
162 return encoder;
163 }
164
165
166
167
168 encoder = state.type2encoder.get(type);
169
170 if (encoder == null) {
171
172
173
174
175 if (triedClasses == null) {
176 triedClasses = new IdentityHashSet<Class>();
177 }
178 triedClasses.add(type);
179
180 Class[] interfaces = type.getInterfaces();
181 for (Class element : interfaces) {
182 encoder = findEncoder(state, element, triedClasses);
183 if (encoder != null) {
184 break;
185 }
186 }
187 }
188
189 if (encoder == null) {
190
191
192
193
194
195 Class superclass = type.getSuperclass();
196 if (superclass != null) {
197 encoder = findEncoder(state, superclass);
198 }
199 }
200
201
202
203
204
205
206 if (encoder != null) {
207 state.findEncoderCache.put(type, encoder);
208 }
209
210 return encoder;
211 }
212
213
214
215
216 public void dispose(IoSession session) throws Exception {
217 session.removeAttribute(STATE);
218 }
219
220 private State getState(IoSession session) throws Exception {
221 State state = (State) session.getAttribute(STATE);
222 if (state == null) {
223 state = new State();
224 State oldState = (State) session.setAttributeIfAbsent(STATE, state);
225 if (oldState != null) {
226 state = oldState;
227 }
228 }
229 return state;
230 }
231
232 private class State {
233 @SuppressWarnings("unchecked")
234 private final Map<Class<?>, MessageEncoder> findEncoderCache = new ConcurrentHashMap<Class<?>, MessageEncoder>();
235
236 @SuppressWarnings("unchecked")
237 private final Map<Class<?>, MessageEncoder> type2encoder = new ConcurrentHashMap<Class<?>, MessageEncoder>();
238
239 @SuppressWarnings("unchecked")
240 private State() throws Exception {
241 for (Map.Entry<Class<?>, MessageEncoderFactory> e: type2encoderFactory.entrySet()) {
242 type2encoder.put(e.getKey(), e.getValue().getEncoder());
243 }
244 }
245 }
246
247 private static class SingletonMessageEncoderFactory<T> implements
248 MessageEncoderFactory<T> {
249 private final MessageEncoder<T> encoder;
250
251 private SingletonMessageEncoderFactory(MessageEncoder<T> encoder) {
252 if (encoder == null) {
253 throw new IllegalArgumentException("encoder");
254 }
255 this.encoder = encoder;
256 }
257
258 public MessageEncoder<T> getEncoder() {
259 return encoder;
260 }
261 }
262
263 private static class DefaultConstructorMessageEncoderFactory<T> implements
264 MessageEncoderFactory<T> {
265 private final Class<MessageEncoder<T>> encoderClass;
266
267 private DefaultConstructorMessageEncoderFactory(Class<MessageEncoder<T>> encoderClass) {
268 if (encoderClass == null) {
269 throw new IllegalArgumentException("encoderClass");
270 }
271
272 if (!MessageEncoder.class.isAssignableFrom(encoderClass)) {
273 throw new IllegalArgumentException(
274 "encoderClass is not assignable to MessageEncoder");
275 }
276 this.encoderClass = encoderClass;
277 }
278
279 public MessageEncoder<T> getEncoder() throws Exception {
280 return encoderClass.newInstance();
281 }
282 }
283 }