1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 package org.apache.mina.filter.codec.demux;
21
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import org.apache.mina.core.session.AttributeKey;
import org.apache.mina.core.session.IoSession;
import org.apache.mina.core.session.UnknownMessageTypeException;
import org.apache.mina.filter.codec.ProtocolEncoder;
import org.apache.mina.filter.codec.ProtocolEncoderOutput;
import org.apache.mina.util.CopyOnWriteMap;
import org.apache.mina.util.IdentityHashSet;
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49 public class DemuxingProtocolEncoder implements ProtocolEncoder {
50
51 private final AttributeKey STATE = new AttributeKey(getClass(), "state");
52
53 @SuppressWarnings("unchecked")
54 private final Map<Class<?>, MessageEncoderFactory> type2encoderFactory = new CopyOnWriteMap<Class<?>, MessageEncoderFactory>();
55
56 private static final Class<?>[] EMPTY_PARAMS = new Class[0];
57
58 public DemuxingProtocolEncoder() {
59 }
60
61 @SuppressWarnings("unchecked")
62 public void addMessageEncoder(Class<?> messageType, Class<? extends MessageEncoder> encoderClass) {
63 if (encoderClass == null) {
64 throw new NullPointerException("encoderClass");
65 }
66
67 try {
68 encoderClass.getConstructor(EMPTY_PARAMS);
69 } catch (NoSuchMethodException e) {
70 throw new IllegalArgumentException(
71 "The specified class doesn't have a public default constructor.");
72 }
73
74 boolean registered = false;
75 if (MessageEncoder.class.isAssignableFrom(encoderClass)) {
76 addMessageEncoder(messageType, new DefaultConstructorMessageEncoderFactory(encoderClass));
77 registered = true;
78 }
79
80 if (!registered) {
81 throw new IllegalArgumentException(
82 "Unregisterable type: " + encoderClass);
83 }
84 }
85
86 @SuppressWarnings("unchecked")
87 public <T> void addMessageEncoder(Class<T> messageType, MessageEncoder<? super T> encoder) {
88 addMessageEncoder(messageType, new SingletonMessageEncoderFactory(encoder));
89 }
90
91 public <T> void addMessageEncoder(Class<T> messageType, MessageEncoderFactory<? super T> factory) {
92 if (messageType == null) {
93 throw new NullPointerException("messageType");
94 }
95
96 if (factory == null) {
97 throw new NullPointerException("factory");
98 }
99
100 synchronized (type2encoderFactory) {
101 if (type2encoderFactory.containsKey(messageType)) {
102 throw new IllegalStateException(
103 "The specified message type (" + messageType.getName() + ") is registered already.");
104 }
105
106 type2encoderFactory.put(messageType, factory);
107 }
108 }
109
110 @SuppressWarnings("unchecked")
111 public void addMessageEncoder(Iterable<Class<?>> messageTypes, Class<? extends MessageEncoder> encoderClass) {
112 for (Class<?> messageType : messageTypes) {
113 addMessageEncoder(messageType, encoderClass);
114 }
115 }
116
117 public <T> void addMessageEncoder(Iterable<Class<? extends T>> messageTypes, MessageEncoder<? super T> encoder) {
118 for (Class<? extends T> messageType : messageTypes) {
119 addMessageEncoder(messageType, encoder);
120 }
121 }
122
123 public <T> void addMessageEncoder(Iterable<Class<? extends T>> messageTypes, MessageEncoderFactory<? super T> factory) {
124 for (Class<? extends T> messageType : messageTypes) {
125 addMessageEncoder(messageType, factory);
126 }
127 }
128
129 public void encode(IoSession session, Object message,
130 ProtocolEncoderOutput out) throws Exception {
131 State state = getState(session);
132 MessageEncoder<Object> encoder = findEncoder(state, message.getClass());
133 if (encoder != null) {
134 encoder.encode(session, message, out);
135 } else {
136 throw new UnknownMessageTypeException(
137 "No message encoder found for message: " + message);
138 }
139 }
140
141 protected MessageEncoder<Object> findEncoder(State state, Class<?> type) {
142 return findEncoder(state, type, null);
143 }
144
145 @SuppressWarnings("unchecked")
146 private MessageEncoder<Object> findEncoder(
147 State state, Class type, Set<Class> triedClasses) {
148 MessageEncoder encoder = null;
149
150 if (triedClasses != null && triedClasses.contains(type)) {
151 return null;
152 }
153
154
155
156
157 encoder = state.findEncoderCache.get(type);
158 if (encoder != null) {
159 return encoder;
160 }
161
162
163
164
165 encoder = state.type2encoder.get(type);
166
167 if (encoder == null) {
168
169
170
171
172 if (triedClasses == null) {
173 triedClasses = new IdentityHashSet<Class>();
174 }
175 triedClasses.add(type);
176
177 Class[] interfaces = type.getInterfaces();
178 for (Class element : interfaces) {
179 encoder = findEncoder(state, element, triedClasses);
180 if (encoder != null) {
181 break;
182 }
183 }
184 }
185
186 if (encoder == null) {
187
188
189
190
191
192 Class superclass = type.getSuperclass();
193 if (superclass != null) {
194 encoder = findEncoder(state, superclass);
195 }
196 }
197
198
199
200
201
202
203 if (encoder != null) {
204 state.findEncoderCache.put(type, encoder);
205 }
206
207 return encoder;
208 }
209
210 public void dispose(IoSession session) throws Exception {
211 session.removeAttribute(STATE);
212 }
213
214 private State getState(IoSession session) throws Exception {
215 State state = (State) session.getAttribute(STATE);
216 if (state == null) {
217 state = new State();
218 State oldState = (State) session.setAttributeIfAbsent(STATE, state);
219 if (oldState != null) {
220 state = oldState;
221 }
222 }
223 return state;
224 }
225
226 private class State {
227 @SuppressWarnings("unchecked")
228 private final Map<Class<?>, MessageEncoder> findEncoderCache = new HashMap<Class<?>, MessageEncoder>();
229
230 @SuppressWarnings("unchecked")
231 private final Map<Class<?>, MessageEncoder> type2encoder = new HashMap<Class<?>, MessageEncoder>();
232
233 @SuppressWarnings("unchecked")
234 private State() throws Exception {
235 for (Map.Entry<Class<?>, MessageEncoderFactory> e: type2encoderFactory.entrySet()) {
236 type2encoder.put(e.getKey(), e.getValue().getEncoder());
237 }
238 }
239 }
240
241 private static class SingletonMessageEncoderFactory<T> implements
242 MessageEncoderFactory<T> {
243 private final MessageEncoder<T> encoder;
244
245 private SingletonMessageEncoderFactory(MessageEncoder<T> encoder) {
246 if (encoder == null) {
247 throw new NullPointerException("encoder");
248 }
249 this.encoder = encoder;
250 }
251
252 public MessageEncoder<T> getEncoder() {
253 return encoder;
254 }
255 }
256
257 private static class DefaultConstructorMessageEncoderFactory<T> implements
258 MessageEncoderFactory<T> {
259 private final Class<MessageEncoder<T>> encoderClass;
260
261 private DefaultConstructorMessageEncoderFactory(Class<MessageEncoder<T>> encoderClass) {
262 if (encoderClass == null) {
263 throw new NullPointerException("encoderClass");
264 }
265
266 if (!MessageEncoder.class.isAssignableFrom(encoderClass)) {
267 throw new IllegalArgumentException(
268 "encoderClass is not assignable to MessageEncoder");
269 }
270 this.encoderClass = encoderClass;
271 }
272
273 public MessageEncoder<T> getEncoder() throws Exception {
274 return encoderClass.newInstance();
275 }
276 }
277 }