001/*
002 *  Licensed to the Apache Software Foundation (ASF) under one
003 *  or more contributor license agreements.  See the NOTICE file
004 *  distributed with this work for additional information
005 *  regarding copyright ownership.  The ASF licenses this file
006 *  to you under the Apache License, Version 2.0 (the
007 *  "License"); you may not use this file except in compliance
008 *  with the License.  You may obtain a copy of the License at
009 *
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *
012 *  Unless required by applicable law or agreed to in writing,
013 *  software distributed under the License is distributed on an
014 *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 *  KIND, either express or implied.  See the License for the
016 *  specific language governing permissions and limitations
017 *  under the License.
018 *
019 */
020package org.apache.mina.filter.codec.demux;
021
022import java.util.Map;
023import java.util.Set;
024import java.util.concurrent.ConcurrentHashMap;
025
026import org.apache.mina.core.session.AttributeKey;
027import org.apache.mina.core.session.IoSession;
028import org.apache.mina.core.session.UnknownMessageTypeException;
029import org.apache.mina.filter.codec.ProtocolEncoder;
030import org.apache.mina.filter.codec.ProtocolEncoderOutput;
031import org.apache.mina.util.CopyOnWriteMap;
032import org.apache.mina.util.IdentityHashSet;
033
034/**
035 * A composite {@link ProtocolEncoder} that demultiplexes incoming message
036 * encoding requests into an appropriate {@link MessageEncoder}.
037 *
038 * <h2>Disposing resources acquired by {@link MessageEncoder}</h2>
039 * <p>
040 * Override {@link #dispose(IoSession)} method. Please don't forget to call
041 * <tt>super.dispose()</tt>.
042 *
043 * @author <a href="http://mina.apache.org">Apache MINA Project</a>
044 *
045 * @see MessageEncoderFactory
046 * @see MessageEncoder
047 */
048public class DemuxingProtocolEncoder implements ProtocolEncoder {
049
050    private final AttributeKey STATE = new AttributeKey(getClass(), "state");
051
052    @SuppressWarnings("rawtypes")
053    private final Map<Class<?>, MessageEncoderFactory> type2encoderFactory = new CopyOnWriteMap<Class<?>, MessageEncoderFactory>();
054
055    private static final Class<?>[] EMPTY_PARAMS = new Class[0];
056
057    public DemuxingProtocolEncoder() {
058        // Do nothing
059    }
060
061    @SuppressWarnings({ "rawtypes", "unchecked" })
062    public void addMessageEncoder(Class<?> messageType, Class<? extends MessageEncoder> encoderClass) {
063        if (encoderClass == null) {
064            throw new IllegalArgumentException("encoderClass");
065        }
066
067        try {
068            encoderClass.getConstructor(EMPTY_PARAMS);
069        } catch (NoSuchMethodException e) {
070            throw new IllegalArgumentException("The specified class doesn't have a public default constructor.");
071        }
072
073        boolean registered = false;
074        if (MessageEncoder.class.isAssignableFrom(encoderClass)) {
075            addMessageEncoder(messageType, new DefaultConstructorMessageEncoderFactory(encoderClass));
076            registered = true;
077        }
078
079        if (!registered) {
080            throw new IllegalArgumentException("Unregisterable type: " + encoderClass);
081        }
082    }
083
084    @SuppressWarnings({ "unchecked", "rawtypes" })
085    public <T> void addMessageEncoder(Class<T> messageType, MessageEncoder<? super T> encoder) {
086        addMessageEncoder(messageType, new SingletonMessageEncoderFactory(encoder));
087    }
088
089    public <T> void addMessageEncoder(Class<T> messageType, MessageEncoderFactory<? super T> factory) {
090        if (messageType == null) {
091            throw new IllegalArgumentException("messageType");
092        }
093
094        if (factory == null) {
095            throw new IllegalArgumentException("factory");
096        }
097
098        synchronized (type2encoderFactory) {
099            if (type2encoderFactory.containsKey(messageType)) {
100                throw new IllegalStateException("The specified message type (" + messageType.getName()
101                        + ") is registered already.");
102            }
103
104            type2encoderFactory.put(messageType, factory);
105        }
106    }
107
108    @SuppressWarnings("rawtypes")
109    public void addMessageEncoder(Iterable<Class<?>> messageTypes, Class<? extends MessageEncoder> encoderClass) {
110        for (Class<?> messageType : messageTypes) {
111            addMessageEncoder(messageType, encoderClass);
112        }
113    }
114
115    public <T> void addMessageEncoder(Iterable<Class<? extends T>> messageTypes, MessageEncoder<? super T> encoder) {
116        for (Class<? extends T> messageType : messageTypes) {
117            addMessageEncoder(messageType, encoder);
118        }
119    }
120
121    public <T> void addMessageEncoder(Iterable<Class<? extends T>> messageTypes,
122            MessageEncoderFactory<? super T> factory) {
123        for (Class<? extends T> messageType : messageTypes) {
124            addMessageEncoder(messageType, factory);
125        }
126    }
127
128    /**
129     * {@inheritDoc}
130     */
131    public void encode(IoSession session, Object message, ProtocolEncoderOutput out) throws Exception {
132        State state = getState(session);
133        MessageEncoder<Object> encoder = findEncoder(state, message.getClass());
134        if (encoder != null) {
135            encoder.encode(session, message, out);
136        } else {
137            throw new UnknownMessageTypeException("No message encoder found for message: " + message);
138        }
139    }
140
141    protected MessageEncoder<Object> findEncoder(State state, Class<?> type) {
142        return findEncoder(state, type, null);
143    }
144
145    @SuppressWarnings("unchecked")
146    private MessageEncoder<Object> findEncoder(State state, Class<?> type, Set<Class<?>> triedClasses) {
147        @SuppressWarnings("rawtypes")
148        MessageEncoder encoder = null;
149
150        if (triedClasses != null && triedClasses.contains(type)) {
151            return null;
152        }
153
154        /*
155         * Try the cache first.
156         */
157        encoder = state.findEncoderCache.get(type);
158
159        if (encoder != null) {
160            return encoder;
161        }
162
163        /*
164         * Try the registered encoders for an immediate match.
165         */
166        encoder = state.type2encoder.get(type);
167
168        if (encoder == null) {
169            /*
170             * No immediate match could be found. Search the type's interfaces.
171             */
172
173            if (triedClasses == null) {
174                triedClasses = new IdentityHashSet<Class<?>>();
175            }
176
177            triedClasses.add(type);
178
179            Class<?>[] interfaces = type.getInterfaces();
180
181            for (Class<?> element : interfaces) {
182                encoder = findEncoder(state, element, triedClasses);
183
184                if (encoder != null) {
185                    break;
186                }
187            }
188        }
189
190        if (encoder == null) {
191            /*
192             * No match in type's interfaces could be found. Search the
193             * superclass.
194             */
195
196            Class<?> superclass = type.getSuperclass();
197
198            if (superclass != null) {
199                encoder = findEncoder(state, superclass);
200            }
201        }
202
203        /*
204         * Make sure the encoder is added to the cache. By updating the cache
205         * here all the types (superclasses and interfaces) in the path which
206         * led to a match will be cached along with the immediate message type.
207         */
208        if (encoder != null) {
209            state.findEncoderCache.put(type, encoder);
210            MessageEncoder<Object> tmpEncoder = state.findEncoderCache.putIfAbsent(type, encoder);
211
212            if (tmpEncoder != null) {
213                encoder = tmpEncoder;
214            }
215        }
216
217        return encoder;
218    }
219
220    /**
221     * {@inheritDoc}
222     */
223    public void dispose(IoSession session) throws Exception {
224        session.removeAttribute(STATE);
225    }
226
227    private State getState(IoSession session) throws Exception {
228        State state = (State) session.getAttribute(STATE);
229        if (state == null) {
230            state = new State();
231            State oldState = (State) session.setAttributeIfAbsent(STATE, state);
232            if (oldState != null) {
233                state = oldState;
234            }
235        }
236        return state;
237    }
238
239    private class State {
240        @SuppressWarnings("rawtypes")
241        private final ConcurrentHashMap<Class<?>, MessageEncoder> findEncoderCache = new ConcurrentHashMap<Class<?>, MessageEncoder>();
242
243        @SuppressWarnings("rawtypes")
244        private final Map<Class<?>, MessageEncoder> type2encoder = new ConcurrentHashMap<Class<?>, MessageEncoder>();
245
246        @SuppressWarnings("rawtypes")
247        private State() throws Exception {
248            for (Map.Entry<Class<?>, MessageEncoderFactory> e : type2encoderFactory.entrySet()) {
249                type2encoder.put(e.getKey(), e.getValue().getEncoder());
250            }
251        }
252    }
253
254    private static class SingletonMessageEncoderFactory<T> implements MessageEncoderFactory<T> {
255        private final MessageEncoder<T> encoder;
256
257        private SingletonMessageEncoderFactory(MessageEncoder<T> encoder) {
258            if (encoder == null) {
259                throw new IllegalArgumentException("encoder");
260            }
261            this.encoder = encoder;
262        }
263
264        public MessageEncoder<T> getEncoder() {
265            return encoder;
266        }
267    }
268
269    private static class DefaultConstructorMessageEncoderFactory<T> implements MessageEncoderFactory<T> {
270        private final Class<MessageEncoder<T>> encoderClass;
271
272        private DefaultConstructorMessageEncoderFactory(Class<MessageEncoder<T>> encoderClass) {
273            if (encoderClass == null) {
274                throw new IllegalArgumentException("encoderClass");
275            }
276
277            if (!MessageEncoder.class.isAssignableFrom(encoderClass)) {
278                throw new IllegalArgumentException("encoderClass is not assignable to MessageEncoder");
279            }
280            this.encoderClass = encoderClass;
281        }
282
283        public MessageEncoder<T> getEncoder() throws Exception {
284            return encoderClass.newInstance();
285        }
286    }
287}