001/*
002 *  Licensed to the Apache Software Foundation (ASF) under one
003 *  or more contributor license agreements.  See the NOTICE file
004 *  distributed with this work for additional information
005 *  regarding copyright ownership.  The ASF licenses this file
006 *  to you under the Apache License, Version 2.0 (the
007 *  "License"); you may not use this file except in compliance
008 *  with the License.  You may obtain a copy of the License at
009 *
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *
012 *  Unless required by applicable law or agreed to in writing,
013 *  software distributed under the License is distributed on an
014 *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 *  KIND, either express or implied.  See the License for the
016 *  specific language governing permissions and limitations
017 *  under the License.
018 *
019 */
020package org.apache.mina.filter.codec.demux;
021
import java.lang.reflect.InvocationTargetException;

import org.apache.mina.core.buffer.IoBuffer;
import org.apache.mina.core.session.AttributeKey;
import org.apache.mina.core.session.IoSession;
import org.apache.mina.filter.codec.CumulativeProtocolDecoder;
import org.apache.mina.filter.codec.ProtocolDecoder;
import org.apache.mina.filter.codec.ProtocolDecoderException;
import org.apache.mina.filter.codec.ProtocolDecoderOutput;
029
030/**
031 * A composite {@link ProtocolDecoder} that demultiplexes incoming {@link IoBuffer}
032 * decoding requests into an appropriate {@link MessageDecoder}.
033 * 
034 * <h2>Internal mechanism of {@link MessageDecoder} selection</h2>
035 * <ol>
036 *   <li>
037 *     {@link DemuxingProtocolDecoder} iterates the list of candidate
038 *     {@link MessageDecoder}s and calls {@link MessageDecoder#decodable(IoSession, IoBuffer)}.
039 *     Initially, all registered {@link MessageDecoder}s are candidates.
040 *   </li>
041 *   <li>
042 *     If {@link MessageDecoderResult#NOT_OK} is returned, it is removed from the candidate
043 *     list.
044 *   </li>
045 *   <li>
046 *     If {@link MessageDecoderResult#NEED_DATA} is returned, it is retained in the candidate
047 *     list, and its {@link MessageDecoder#decodable(IoSession, IoBuffer)} will be invoked
048 *     again when more data is received.
049 *   </li>
050 *   <li>
051 *     If {@link MessageDecoderResult#OK} is returned, {@link DemuxingProtocolDecoder}
052 *     found the right {@link MessageDecoder}.
053 *   </li>
054 *   <li>
055 *     If there's no candidate left, an exception is raised.  Otherwise, 
056 *     {@link DemuxingProtocolDecoder} will keep iterating the candidate list.
057 *   </li>
058 * </ol>
059 * 
060 * Please note that any change of position and limit of the specified {@link IoBuffer}
061 * in {@link MessageDecoder#decodable(IoSession, IoBuffer)} will be reverted back to its
062 * original value.
063 * <p>
064 * Once a {@link MessageDecoder} is selected, {@link DemuxingProtocolDecoder} calls
065 * {@link MessageDecoder#decode(IoSession, IoBuffer, ProtocolDecoderOutput)} continuously
066 * reading its return value:
067 * <ul>
068 *   <li>
069 *     {@link MessageDecoderResult#NOT_OK} - protocol violation; {@link ProtocolDecoderException}
070 *     is raised automatically.
071 *   </li>
072 *   <li>
073 *     {@link MessageDecoderResult#NEED_DATA} - needs more data to read the whole message;
074 *     {@link MessageDecoder#decode(IoSession, IoBuffer, ProtocolDecoderOutput)}
075 *     will be invoked again when more data is received.
076 *   </li>
077 *   <li>
078 *     {@link MessageDecoderResult#OK} - successfully decoded a message; the candidate list will
079 *     be reset and the selection process will start over.
080 *   </li>
081 * </ul>
082 *
083 * @author <a href="http://mina.apache.org">Apache MINA Project</a>
084 *
085 * @see MessageDecoderFactory
086 * @see MessageDecoder
087 */
088public class DemuxingProtocolDecoder extends CumulativeProtocolDecoder {
089
090    private final AttributeKey STATE = new AttributeKey(getClass(), "state");
091
092    private MessageDecoderFactory[] decoderFactories = new MessageDecoderFactory[0];
093
094    private static final Class<?>[] EMPTY_PARAMS = new Class[0];
095
096    public DemuxingProtocolDecoder() {
097        // Do nothing
098    }
099
100    public void addMessageDecoder(Class<? extends MessageDecoder> decoderClass) {
101        if (decoderClass == null) {
102            throw new IllegalArgumentException("decoderClass");
103        }
104
105        try {
106            decoderClass.getConstructor(EMPTY_PARAMS);
107        } catch (NoSuchMethodException e) {
108            throw new IllegalArgumentException("The specified class doesn't have a public default constructor.");
109        }
110
111        boolean registered = false;
112        if (MessageDecoder.class.isAssignableFrom(decoderClass)) {
113            addMessageDecoder(new DefaultConstructorMessageDecoderFactory(decoderClass));
114            registered = true;
115        }
116
117        if (!registered) {
118            throw new IllegalArgumentException("Unregisterable type: " + decoderClass);
119        }
120    }
121
122    public void addMessageDecoder(MessageDecoder decoder) {
123        addMessageDecoder(new SingletonMessageDecoderFactory(decoder));
124    }
125
126    public void addMessageDecoder(MessageDecoderFactory factory) {
127        if (factory == null) {
128            throw new IllegalArgumentException("factory");
129        }
130        MessageDecoderFactory[] decoderFactories = this.decoderFactories;
131        MessageDecoderFactory[] newDecoderFactories = new MessageDecoderFactory[decoderFactories.length + 1];
132        System.arraycopy(decoderFactories, 0, newDecoderFactories, 0, decoderFactories.length);
133        newDecoderFactories[decoderFactories.length] = factory;
134        this.decoderFactories = newDecoderFactories;
135    }
136
137    /**
138     * {@inheritDoc}
139     */
140    @Override
141    protected boolean doDecode(IoSession session, IoBuffer in, ProtocolDecoderOutput out) throws Exception {
142        State state = getState(session);
143
144        if (state.currentDecoder == null) {
145            MessageDecoder[] decoders = state.decoders;
146            int undecodables = 0;
147
148            for (int i = decoders.length - 1; i >= 0; i--) {
149                MessageDecoder decoder = decoders[i];
150                int limit = in.limit();
151                int pos = in.position();
152
153                MessageDecoderResult result;
154
155                try {
156                    result = decoder.decodable(session, in);
157                } finally {
158                    in.position(pos);
159                    in.limit(limit);
160                }
161
162                if (result == MessageDecoder.OK) {
163                    state.currentDecoder = decoder;
164                    break;
165                } else if (result == MessageDecoder.NOT_OK) {
166                    undecodables++;
167                } else if (result != MessageDecoder.NEED_DATA) {
168                    throw new IllegalStateException("Unexpected decode result (see your decodable()): " + result);
169                }
170            }
171
172            if (undecodables == decoders.length) {
173                // Throw an exception if all decoders cannot decode data.
174                String dump = in.getHexDump();
175                in.position(in.limit()); // Skip data
176                ProtocolDecoderException e = new ProtocolDecoderException("No appropriate message decoder: " + dump);
177                e.setHexdump(dump);
178                throw e;
179            }
180
181            if (state.currentDecoder == null) {
182                // Decoder is not determined yet (i.e. we need more data)
183                return false;
184            }
185        }
186
187        try {
188            MessageDecoderResult result = state.currentDecoder.decode(session, in, out);
189            if (result == MessageDecoder.OK) {
190                state.currentDecoder = null;
191                return true;
192            } else if (result == MessageDecoder.NEED_DATA) {
193                return false;
194            } else if (result == MessageDecoder.NOT_OK) {
195                state.currentDecoder = null;
196                throw new ProtocolDecoderException("Message decoder returned NOT_OK.");
197            } else {
198                state.currentDecoder = null;
199                throw new IllegalStateException("Unexpected decode result (see your decode()): " + result);
200            }
201        } catch (Exception e) {
202            state.currentDecoder = null;
203            throw e;
204        }
205    }
206
207    /**
208     * {@inheritDoc}
209     */
210    @Override
211    public void finishDecode(IoSession session, ProtocolDecoderOutput out) throws Exception {
212        super.finishDecode(session, out);
213        State state = getState(session);
214        MessageDecoder currentDecoder = state.currentDecoder;
215        if (currentDecoder == null) {
216            return;
217        }
218
219        currentDecoder.finishDecode(session, out);
220    }
221
222    /**
223     * {@inheritDoc}
224     */
225    @Override
226    public void dispose(IoSession session) throws Exception {
227        super.dispose(session);
228        session.removeAttribute(STATE);
229    }
230
231    private State getState(IoSession session) throws Exception {
232        State state = (State) session.getAttribute(STATE);
233
234        if (state == null) {
235            state = new State();
236            State oldState = (State) session.setAttributeIfAbsent(STATE, state);
237
238            if (oldState != null) {
239                state = oldState;
240            }
241        }
242
243        return state;
244    }
245
246    private class State {
247        private final MessageDecoder[] decoders;
248
249        private MessageDecoder currentDecoder;
250
251        private State() throws Exception {
252            MessageDecoderFactory[] decoderFactories = DemuxingProtocolDecoder.this.decoderFactories;
253            decoders = new MessageDecoder[decoderFactories.length];
254            for (int i = decoderFactories.length - 1; i >= 0; i--) {
255                decoders[i] = decoderFactories[i].getDecoder();
256            }
257        }
258    }
259
260    private static class SingletonMessageDecoderFactory implements MessageDecoderFactory {
261        private final MessageDecoder decoder;
262
263        private SingletonMessageDecoderFactory(MessageDecoder decoder) {
264            if (decoder == null) {
265                throw new IllegalArgumentException("decoder");
266            }
267            this.decoder = decoder;
268        }
269
270        public MessageDecoder getDecoder() {
271            return decoder;
272        }
273    }
274
275    private static class DefaultConstructorMessageDecoderFactory implements MessageDecoderFactory {
276        private final Class<?> decoderClass;
277
278        private DefaultConstructorMessageDecoderFactory(Class<?> decoderClass) {
279            if (decoderClass == null) {
280                throw new IllegalArgumentException("decoderClass");
281            }
282
283            if (!MessageDecoder.class.isAssignableFrom(decoderClass)) {
284                throw new IllegalArgumentException("decoderClass is not assignable to MessageDecoder");
285            }
286            this.decoderClass = decoderClass;
287        }
288
289        public MessageDecoder getDecoder() throws Exception {
290            return (MessageDecoder) decoderClass.newInstance();
291        }
292    }
293}