001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 * 019 */ 020package org.apache.mina.filter.codec.demux; 021 022import java.util.Map; 023import java.util.Set; 024import java.util.concurrent.ConcurrentHashMap; 025 026import org.apache.mina.core.session.AttributeKey; 027import org.apache.mina.core.session.IoSession; 028import org.apache.mina.core.session.UnknownMessageTypeException; 029import org.apache.mina.filter.codec.ProtocolEncoder; 030import org.apache.mina.filter.codec.ProtocolEncoderOutput; 031import org.apache.mina.util.CopyOnWriteMap; 032import org.apache.mina.util.IdentityHashSet; 033 034/** 035 * A composite {@link ProtocolEncoder} that demultiplexes incoming message 036 * encoding requests into an appropriate {@link MessageEncoder}. 037 * 038 * <h2>Disposing resources acquired by {@link MessageEncoder}</h2> 039 * <p> 040 * Override {@link #dispose(IoSession)} method. Please don't forget to call 041 * <tt>super.dispose()</tt>. 042 * 043 * @author <a href="http://mina.apache.org">Apache MINA Project</a> 044 * 045 * @see MessageEncoderFactory 046 * @see MessageEncoder 047 */ 048public class DemuxingProtocolEncoder implements ProtocolEncoder { 049 050 private final AttributeKey STATE = new AttributeKey(getClass(), "state"); 051 052 @SuppressWarnings("rawtypes") 053 private final Map<Class<?>, MessageEncoderFactory> type2encoderFactory = new CopyOnWriteMap<Class<?>, MessageEncoderFactory>(); 054 055 private static final Class<?>[] EMPTY_PARAMS = new Class[0]; 056 057 public DemuxingProtocolEncoder() { 058 // Do nothing 059 } 060 061 @SuppressWarnings({ "rawtypes", "unchecked" }) 062 public void addMessageEncoder(Class<?> messageType, Class<? extends MessageEncoder> encoderClass) { 063 if (encoderClass == null) { 064 throw new IllegalArgumentException("encoderClass"); 065 } 066 067 try { 068 encoderClass.getConstructor(EMPTY_PARAMS); 069 } catch (NoSuchMethodException e) { 070 throw new IllegalArgumentException("The specified class doesn't have a public default constructor."); 071 } 072 073 boolean registered = false; 074 if (MessageEncoder.class.isAssignableFrom(encoderClass)) { 075 addMessageEncoder(messageType, new DefaultConstructorMessageEncoderFactory(encoderClass)); 076 registered = true; 077 } 078 079 if (!registered) { 080 throw new IllegalArgumentException("Unregisterable type: " + encoderClass); 081 } 082 } 083 084 @SuppressWarnings({ "unchecked", "rawtypes" }) 085 public <T> void addMessageEncoder(Class<T> messageType, MessageEncoder<? super T> encoder) { 086 addMessageEncoder(messageType, new SingletonMessageEncoderFactory(encoder)); 087 } 088 089 public <T> void addMessageEncoder(Class<T> messageType, MessageEncoderFactory<? super T> factory) { 090 if (messageType == null) { 091 throw new IllegalArgumentException("messageType"); 092 } 093 094 if (factory == null) { 095 throw new IllegalArgumentException("factory"); 096 } 097 098 synchronized (type2encoderFactory) { 099 if (type2encoderFactory.containsKey(messageType)) { 100 throw new IllegalStateException("The specified message type (" + messageType.getName() 101 + ") is registered already."); 102 } 103 104 type2encoderFactory.put(messageType, factory); 105 } 106 } 107 108 @SuppressWarnings("rawtypes") 109 public void addMessageEncoder(Iterable<Class<?>> messageTypes, Class<? extends MessageEncoder> encoderClass) { 110 for (Class<?> messageType : messageTypes) { 111 addMessageEncoder(messageType, encoderClass); 112 } 113 } 114 115 public <T> void addMessageEncoder(Iterable<Class<? extends T>> messageTypes, MessageEncoder<? super T> encoder) { 116 for (Class<? extends T> messageType : messageTypes) { 117 addMessageEncoder(messageType, encoder); 118 } 119 } 120 121 public <T> void addMessageEncoder(Iterable<Class<? extends T>> messageTypes, 122 MessageEncoderFactory<? super T> factory) { 123 for (Class<? extends T> messageType : messageTypes) { 124 addMessageEncoder(messageType, factory); 125 } 126 } 127 128 /** 129 * {@inheritDoc} 130 */ 131 public void encode(IoSession session, Object message, ProtocolEncoderOutput out) throws Exception { 132 State state = getState(session); 133 MessageEncoder<Object> encoder = findEncoder(state, message.getClass()); 134 if (encoder != null) { 135 encoder.encode(session, message, out); 136 } else { 137 throw new UnknownMessageTypeException("No message encoder found for message: " + message); 138 } 139 } 140 141 protected MessageEncoder<Object> findEncoder(State state, Class<?> type) { 142 return findEncoder(state, type, null); 143 } 144 145 @SuppressWarnings("unchecked") 146 private MessageEncoder<Object> findEncoder(State state, Class<?> type, Set<Class<?>> triedClasses) { 147 @SuppressWarnings("rawtypes") 148 MessageEncoder encoder = null; 149 150 if (triedClasses != null && triedClasses.contains(type)) { 151 return null; 152 } 153 154 /* 155 * Try the cache first. 156 */ 157 encoder = state.findEncoderCache.get(type); 158 159 if (encoder != null) { 160 return encoder; 161 } 162 163 /* 164 * Try the registered encoders for an immediate match. 165 */ 166 encoder = state.type2encoder.get(type); 167 168 if (encoder == null) { 169 /* 170 * No immediate match could be found. Search the type's interfaces. 171 */ 172 173 if (triedClasses == null) { 174 triedClasses = new IdentityHashSet<Class<?>>(); 175 } 176 177 triedClasses.add(type); 178 179 Class<?>[] interfaces = type.getInterfaces(); 180 181 for (Class<?> element : interfaces) { 182 encoder = findEncoder(state, element, triedClasses); 183 184 if (encoder != null) { 185 break; 186 } 187 } 188 } 189 190 if (encoder == null) { 191 /* 192 * No match in type's interfaces could be found. Search the 193 * superclass. 194 */ 195 196 Class<?> superclass = type.getSuperclass(); 197 198 if (superclass != null) { 199 encoder = findEncoder(state, superclass); 200 } 201 } 202 203 /* 204 * Make sure the encoder is added to the cache. By updating the cache 205 * here all the types (superclasses and interfaces) in the path which 206 * led to a match will be cached along with the immediate message type. 207 */ 208 if (encoder != null) { 209 state.findEncoderCache.put(type, encoder); 210 MessageEncoder<Object> tmpEncoder = state.findEncoderCache.putIfAbsent(type, encoder); 211 212 if (tmpEncoder != null) { 213 encoder = tmpEncoder; 214 } 215 } 216 217 return encoder; 218 } 219 220 /** 221 * {@inheritDoc} 222 */ 223 public void dispose(IoSession session) throws Exception { 224 session.removeAttribute(STATE); 225 } 226 227 private State getState(IoSession session) throws Exception { 228 State state = (State) session.getAttribute(STATE); 229 if (state == null) { 230 state = new State(); 231 State oldState = (State) session.setAttributeIfAbsent(STATE, state); 232 if (oldState != null) { 233 state = oldState; 234 } 235 } 236 return state; 237 } 238 239 private class State { 240 @SuppressWarnings("rawtypes") 241 private final ConcurrentHashMap<Class<?>, MessageEncoder> findEncoderCache = new ConcurrentHashMap<Class<?>, MessageEncoder>(); 242 243 @SuppressWarnings("rawtypes") 244 private final Map<Class<?>, MessageEncoder> type2encoder = new ConcurrentHashMap<Class<?>, MessageEncoder>(); 245 246 @SuppressWarnings("rawtypes") 247 private State() throws Exception { 248 for (Map.Entry<Class<?>, MessageEncoderFactory> e : type2encoderFactory.entrySet()) { 249 type2encoder.put(e.getKey(), e.getValue().getEncoder()); 250 } 251 } 252 } 253 254 private static class SingletonMessageEncoderFactory<T> implements MessageEncoderFactory<T> { 255 private final MessageEncoder<T> encoder; 256 257 private SingletonMessageEncoderFactory(MessageEncoder<T> encoder) { 258 if (encoder == null) { 259 throw new IllegalArgumentException("encoder"); 260 } 261 this.encoder = encoder; 262 } 263 264 public MessageEncoder<T> getEncoder() { 265 return encoder; 266 } 267 } 268 269 private static class DefaultConstructorMessageEncoderFactory<T> implements MessageEncoderFactory<T> { 270 private final Class<MessageEncoder<T>> encoderClass; 271 272 private DefaultConstructorMessageEncoderFactory(Class<MessageEncoder<T>> encoderClass) { 273 if (encoderClass == null) { 274 throw new IllegalArgumentException("encoderClass"); 275 } 276 277 if (!MessageEncoder.class.isAssignableFrom(encoderClass)) { 278 throw new IllegalArgumentException("encoderClass is not assignable to MessageEncoder"); 279 } 280 this.encoderClass = encoderClass; 281 } 282 283 public MessageEncoder<T> getEncoder() throws Exception { 284 return encoderClass.newInstance(); 285 } 286 } 287}