001/*
002 *  Licensed to the Apache Software Foundation (ASF) under one
003 *  or more contributor license agreements.  See the NOTICE file
004 *  distributed with this work for additional information
005 *  regarding copyright ownership.  The ASF licenses this file
006 *  to you under the Apache License, Version 2.0 (the
007 *  "License"); you may not use this file except in compliance
008 *  with the License.  You may obtain a copy of the License at
009 *
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *
012 *  Unless required by applicable law or agreed to in writing,
013 *  software distributed under the License is distributed on an
014 *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 *  KIND, either express or implied.  See the License for the
016 *  specific language governing permissions and limitations
017 *  under the License.
018 *
019 */
020package org.apache.mina.filter.firewall;
021
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

import org.apache.mina.core.filterchain.IoFilter;
import org.apache.mina.core.filterchain.IoFilterAdapter;
import org.apache.mina.core.session.IdleStatus;
import org.apache.mina.core.session.IoSession;
import org.apache.mina.core.write.WriteRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
035
036/**
037 * A {@link IoFilter} which blocks connections from blacklisted remote
038 * address.
039 *
040 * @author <a href="http://mina.apache.org">Apache MINA Project</a>
041 * @org.apache.xbean.XBean
042 */
043public class BlacklistFilter extends IoFilterAdapter {
044    /** The list of blocked addresses */
045    private final List<Subnet> blacklist = new CopyOnWriteArrayList<Subnet>();
046
047    private final static Logger LOGGER = LoggerFactory.getLogger(BlacklistFilter.class);
048
049    /**
050     * Sets the addresses to be blacklisted.
051     *
052     * NOTE: this call will remove any previously blacklisted addresses.
053     *
054     * @param addresses an array of addresses to be blacklisted.
055     */
056    public void setBlacklist(InetAddress[] addresses) {
057        if (addresses == null) {
058            throw new IllegalArgumentException("addresses");
059        }
060
061        blacklist.clear();
062
063        for (int i = 0; i < addresses.length; i++) {
064            InetAddress addr = addresses[i];
065            block(addr);
066        }
067    }
068
069    /**
070     * Sets the subnets to be blacklisted.
071     *
072     * NOTE: this call will remove any previously blacklisted subnets.
073     *
074     * @param subnets an array of subnets to be blacklisted.
075     */
076    public void setSubnetBlacklist(Subnet[] subnets) {
077        if (subnets == null) {
078            throw new IllegalArgumentException("Subnets must not be null");
079        }
080
081        blacklist.clear();
082
083        for (Subnet subnet : subnets) {
084            block(subnet);
085        }
086    }
087
088    /**
089     * Sets the addresses to be blacklisted.
090     *
091     * NOTE: this call will remove any previously blacklisted addresses.
092     *
093     * @param addresses a collection of InetAddress objects representing the
094     *        addresses to be blacklisted.
095     * @throws IllegalArgumentException if the specified collections contains
096     *         non-{@link InetAddress} objects.
097     */
098    public void setBlacklist(Iterable<InetAddress> addresses) {
099        if (addresses == null) {
100            throw new IllegalArgumentException("addresses");
101        }
102
103        blacklist.clear();
104
105        for (InetAddress address : addresses) {
106            block(address);
107        }
108    }
109
110    /**
111     * Sets the subnets to be blacklisted.
112     *
113     * NOTE: this call will remove any previously blacklisted subnets.
114     *
115     * @param subnets an array of subnets to be blacklisted.
116     */
117    public void setSubnetBlacklist(Iterable<Subnet> subnets) {
118        if (subnets == null) {
119            throw new IllegalArgumentException("Subnets must not be null");
120        }
121
122        blacklist.clear();
123
124        for (Subnet subnet : subnets) {
125            block(subnet);
126        }
127    }
128
129    /**
130     * Blocks the specified endpoint.
131     * 
132     * @param address The address to block
133     */
134    public void block(InetAddress address) {
135        if (address == null) {
136            throw new IllegalArgumentException("Adress to block can not be null");
137        }
138
139        block(new Subnet(address, 32));
140    }
141
142    /**
143     * Blocks the specified subnet.
144     * 
145     * @param subnet The subnet to block
146     */
147    public void block(Subnet subnet) {
148        if (subnet == null) {
149            throw new IllegalArgumentException("Subnet can not be null");
150        }
151
152        blacklist.add(subnet);
153    }
154
155    /**
156     * Unblocks the specified endpoint.
157     * 
158     * @param address The address to unblock
159     */
160    public void unblock(InetAddress address) {
161        if (address == null) {
162            throw new IllegalArgumentException("Adress to unblock can not be null");
163        }
164
165        unblock(new Subnet(address, 32));
166    }
167
168    /**
169     * Unblocks the specified subnet.
170     * 
171     * @param subnet The subnet to unblock
172     */
173    public void unblock(Subnet subnet) {
174        if (subnet == null) {
175            throw new IllegalArgumentException("Subnet can not be null");
176        }
177
178        blacklist.remove(subnet);
179    }
180
181    @Override
182    public void sessionCreated(NextFilter nextFilter, IoSession session) {
183        if (!isBlocked(session)) {
184            // forward if not blocked
185            nextFilter.sessionCreated(session);
186        } else {
187            blockSession(session);
188        }
189    }
190
191    @Override
192    public void sessionOpened(NextFilter nextFilter, IoSession session) throws Exception {
193        if (!isBlocked(session)) {
194            // forward if not blocked
195            nextFilter.sessionOpened(session);
196        } else {
197            blockSession(session);
198        }
199    }
200
201    @Override
202    public void sessionClosed(NextFilter nextFilter, IoSession session) throws Exception {
203        if (!isBlocked(session)) {
204            // forward if not blocked
205            nextFilter.sessionClosed(session);
206        } else {
207            blockSession(session);
208        }
209    }
210
211    @Override
212    public void sessionIdle(NextFilter nextFilter, IoSession session, IdleStatus status) throws Exception {
213        if (!isBlocked(session)) {
214            // forward if not blocked
215            nextFilter.sessionIdle(session, status);
216        } else {
217            blockSession(session);
218        }
219    }
220
221    @Override
222    public void messageReceived(NextFilter nextFilter, IoSession session, Object message) {
223        if (!isBlocked(session)) {
224            // forward if not blocked
225            nextFilter.messageReceived(session, message);
226        } else {
227            blockSession(session);
228        }
229    }
230
231    @Override
232    public void messageSent(NextFilter nextFilter, IoSession session, WriteRequest writeRequest) throws Exception {
233        if (!isBlocked(session)) {
234            // forward if not blocked
235            nextFilter.messageSent(session, writeRequest);
236        } else {
237            blockSession(session);
238        }
239    }
240
241    private void blockSession(IoSession session) {
242        LOGGER.warn("Remote address in the blacklist; closing.");
243        session.closeNow();
244    }
245
246    private boolean isBlocked(IoSession session) {
247        SocketAddress remoteAddress = session.getRemoteAddress();
248
249        if (remoteAddress instanceof InetSocketAddress) {
250            InetAddress address = ((InetSocketAddress) remoteAddress).getAddress();
251
252            // check all subnets
253            for (Subnet subnet : blacklist) {
254                if (subnet.inSubnet(address)) {
255                    return true;
256                }
257            }
258        }
259
260        return false;
261    }
262}