001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.activemq.transport.auto; 018 019import java.io.IOException; 020import java.io.InputStream; 021import java.net.Socket; 022import java.net.URI; 023import java.net.URISyntaxException; 024import java.nio.ByteBuffer; 025import java.util.Map; 026import java.util.Set; 027import java.util.concurrent.ConcurrentHashMap; 028import java.util.concurrent.ConcurrentMap; 029import java.util.concurrent.ExecutorService; 030import java.util.concurrent.Executors; 031import java.util.concurrent.Future; 032import java.util.concurrent.LinkedBlockingQueue; 033import java.util.concurrent.ThreadPoolExecutor; 034import java.util.concurrent.TimeUnit; 035import java.util.concurrent.TimeoutException; 036import java.util.concurrent.atomic.AtomicInteger; 037 038import javax.net.ServerSocketFactory; 039 040import org.apache.activemq.broker.BrokerService; 041import org.apache.activemq.broker.BrokerServiceAware; 042import org.apache.activemq.openwire.OpenWireFormatFactory; 043import org.apache.activemq.transport.InactivityIOException; 044import org.apache.activemq.transport.Transport; 045import org.apache.activemq.transport.TransportFactory; 046import org.apache.activemq.transport.TransportServer; 047import org.apache.activemq.transport.protocol.AmqpProtocolVerifier; 048import org.apache.activemq.transport.protocol.MqttProtocolVerifier; 049import org.apache.activemq.transport.protocol.OpenWireProtocolVerifier; 050import org.apache.activemq.transport.protocol.ProtocolVerifier; 051import org.apache.activemq.transport.protocol.StompProtocolVerifier; 052import org.apache.activemq.transport.tcp.TcpTransport; 053import org.apache.activemq.transport.tcp.TcpTransport.InitBuffer; 054import org.apache.activemq.transport.tcp.TcpTransportFactory; 055import org.apache.activemq.transport.tcp.TcpTransportServer; 056import org.apache.activemq.util.FactoryFinder; 057import org.apache.activemq.util.IOExceptionSupport; 058import org.apache.activemq.util.IntrospectionSupport; 059import org.apache.activemq.util.ServiceStopper; 060import org.apache.activemq.wireformat.WireFormat; 061import org.apache.activemq.wireformat.WireFormatFactory; 062import org.slf4j.Logger; 063import org.slf4j.LoggerFactory; 064 065/** 066 * A TCP based implementation of {@link TransportServer} 067 */ 068public class AutoTcpTransportServer extends TcpTransportServer { 069 070 private static final Logger LOG = LoggerFactory.getLogger(AutoTcpTransportServer.class); 071 072 protected Map<String, Map<String, Object>> wireFormatOptions; 073 protected Map<String, Object> autoTransportOptions; 074 protected Set<String> enabledProtocols; 075 protected final Map<String, ProtocolVerifier> protocolVerifiers = new ConcurrentHashMap<String, ProtocolVerifier>(); 076 077 protected BrokerService brokerService; 078 079 protected int maxConnectionThreadPoolSize = Integer.MAX_VALUE; 080 protected int protocolDetectionTimeOut = 30000; 081 082 private static final FactoryFinder TRANSPORT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/transport/"); 083 private final ConcurrentMap<String, TransportFactory> transportFactories = new ConcurrentHashMap<String, TransportFactory>(); 084 085 private static final FactoryFinder WIREFORMAT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/wireformat/"); 086 087 public WireFormatFactory findWireFormatFactory(String scheme, Map<String, Map<String, Object>> options) throws IOException { 088 WireFormatFactory wff = null; 089 try { 090 wff = (WireFormatFactory)WIREFORMAT_FACTORY_FINDER.newInstance(scheme); 091 if (options != null) { 092 IntrospectionSupport.setProperties(wff, options.get(AutoTransportUtils.ALL)); 093 IntrospectionSupport.setProperties(wff, options.get(scheme)); 094 } 095 if (wff instanceof OpenWireFormatFactory) { 096 protocolVerifiers.put(AutoTransportUtils.OPENWIRE, new OpenWireProtocolVerifier((OpenWireFormatFactory) wff)); 097 } 098 return wff; 099 } catch (Throwable e) { 100 throw IOExceptionSupport.create("Could not create wire format factory for: " + scheme + ", reason: " + e, e); 101 } 102 } 103 104 public TransportFactory findTransportFactory(String scheme, Map<String, ?> options) throws IOException { 105 scheme = append(scheme, "nio"); 106 scheme = append(scheme, "ssl"); 107 108 if (scheme.isEmpty()) { 109 scheme = "tcp"; 110 } 111 112 TransportFactory tf = transportFactories.get(scheme); 113 if (tf == null) { 114 // Try to load if from a META-INF property. 115 try { 116 tf = (TransportFactory)TRANSPORT_FACTORY_FINDER.newInstance(scheme); 117 if (options != null) { 118 IntrospectionSupport.setProperties(tf, options); 119 } 120 transportFactories.put(scheme, tf); 121 } catch (Throwable e) { 122 throw IOExceptionSupport.create("Transport scheme NOT recognized: [" + scheme + "]", e); 123 } 124 } 125 return tf; 126 } 127 128 protected String append(String currentScheme, String scheme) { 129 if (this.getBindLocation().getScheme().contains(scheme)) { 130 if (!currentScheme.isEmpty()) { 131 currentScheme += "+"; 132 } 133 currentScheme += scheme; 134 } 135 return currentScheme; 136 } 137 138 /** 139 * @param transportFactory 140 * @param location 141 * @param serverSocketFactory 142 * @throws IOException 143 * @throws URISyntaxException 144 */ 145 public AutoTcpTransportServer(TcpTransportFactory transportFactory, 146 URI location, ServerSocketFactory serverSocketFactory, BrokerService brokerService, 147 Set<String> enabledProtocols) 148 throws IOException, URISyntaxException { 149 super(transportFactory, location, serverSocketFactory); 150 151 //Use an executor service here to handle new connections. Setting the max number 152 //of threads to the maximum number of connections the thread count isn't unbounded 153 service = new ThreadPoolExecutor(maxConnectionThreadPoolSize, 154 maxConnectionThreadPoolSize, 155 30L, TimeUnit.SECONDS, 156 new LinkedBlockingQueue<Runnable>()); 157 //allow the thread pool to shrink if the max number of threads isn't needed 158 service.allowCoreThreadTimeOut(true); 159 160 this.brokerService = brokerService; 161 this.enabledProtocols = enabledProtocols; 162 initProtocolVerifiers(); 163 } 164 165 public int getMaxConnectionThreadPoolSize() { 166 return maxConnectionThreadPoolSize; 167 } 168 169 public void setMaxConnectionThreadPoolSize(int maxConnectionThreadPoolSize) { 170 this.maxConnectionThreadPoolSize = maxConnectionThreadPoolSize; 171 service.setCorePoolSize(maxConnectionThreadPoolSize); 172 service.setMaximumPoolSize(maxConnectionThreadPoolSize); 173 } 174 175 public void setProtocolDetectionTimeOut(int protocolDetectionTimeOut) { 176 this.protocolDetectionTimeOut = protocolDetectionTimeOut; 177 } 178 179 @Override 180 public void setWireFormatFactory(WireFormatFactory factory) { 181 super.setWireFormatFactory(factory); 182 initOpenWireProtocolVerifier(); 183 } 184 185 protected void initProtocolVerifiers() { 186 initOpenWireProtocolVerifier(); 187 188 if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.AMQP)) { 189 protocolVerifiers.put(AutoTransportUtils.AMQP, new AmqpProtocolVerifier()); 190 } 191 if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.STOMP)) { 192 protocolVerifiers.put(AutoTransportUtils.STOMP, new StompProtocolVerifier()); 193 } 194 if (isAllProtocols()|| enabledProtocols.contains(AutoTransportUtils.MQTT)) { 195 protocolVerifiers.put(AutoTransportUtils.MQTT, new MqttProtocolVerifier()); 196 } 197 } 198 199 protected void initOpenWireProtocolVerifier() { 200 if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.OPENWIRE)) { 201 OpenWireProtocolVerifier owpv; 202 if (wireFormatFactory instanceof OpenWireFormatFactory) { 203 owpv = new OpenWireProtocolVerifier((OpenWireFormatFactory) wireFormatFactory); 204 } else { 205 owpv = new OpenWireProtocolVerifier(new OpenWireFormatFactory()); 206 } 207 protocolVerifiers.put(AutoTransportUtils.OPENWIRE, owpv); 208 } 209 } 210 211 protected boolean isAllProtocols() { 212 return enabledProtocols == null || enabledProtocols.isEmpty(); 213 } 214 215 216 protected final ThreadPoolExecutor service; 217 218 219 /** 220 * This holds the initial buffer that has been read to detect the protocol. 221 */ 222 public InitBuffer initBuffer; 223 224 @Override 225 protected void handleSocket(final Socket socket) { 226 final AutoTcpTransportServer server = this; 227 //This needs to be done in a new thread because 228 //the socket might be waiting on the client to send bytes 229 //doHandleSocket can't complete until the protocol can be detected 230 service.submit(new Runnable() { 231 @Override 232 public void run() { 233 server.doHandleSocket(socket); 234 } 235 }); 236 } 237 238 @Override 239 protected TransportInfo configureTransport(final TcpTransportServer server, final Socket socket) throws Exception { 240 final InputStream is = socket.getInputStream(); 241 ExecutorService executor = Executors.newSingleThreadExecutor(); 242 243 final AtomicInteger readBytes = new AtomicInteger(0); 244 final ByteBuffer data = ByteBuffer.allocate(8); 245 // We need to peak at the first 8 bytes of the buffer to detect the protocol 246 Future<?> future = executor.submit(new Runnable() { 247 @Override 248 public void run() { 249 try { 250 do { 251 int read = is.read(); 252 if (read == -1) { 253 throw new IOException("Connection failed, stream is closed."); 254 } 255 data.put((byte) read); 256 readBytes.incrementAndGet(); 257 } while (readBytes.get() < 8); 258 } catch (Exception e) { 259 throw new IllegalStateException(e); 260 } 261 } 262 }); 263 264 waitForProtocolDetectionFinish(future, readBytes); 265 data.flip(); 266 ProtocolInfo protocolInfo = detectProtocol(data.array()); 267 268 initBuffer = new InitBuffer(readBytes.get(), ByteBuffer.allocate(readBytes.get())); 269 initBuffer.buffer.put(data.array()); 270 271 if (protocolInfo.detectedTransportFactory instanceof BrokerServiceAware) { 272 ((BrokerServiceAware) protocolInfo.detectedTransportFactory).setBrokerService(brokerService); 273 } 274 275 WireFormat format = protocolInfo.detectedWireFormatFactory.createWireFormat(); 276 Transport transport = createTransport(socket, format,protocolInfo.detectedTransportFactory); 277 278 return new TransportInfo(format, transport, protocolInfo.detectedTransportFactory); 279 } 280 281 protected void waitForProtocolDetectionFinish(final Future<?> future, final AtomicInteger readBytes) throws Exception { 282 try { 283 //Wait for protocolDetectionTimeOut if defined 284 if (protocolDetectionTimeOut > 0) { 285 future.get(protocolDetectionTimeOut, TimeUnit.MILLISECONDS); 286 } else { 287 future.get(); 288 } 289 } catch (TimeoutException e) { 290 throw new InactivityIOException("Client timed out before wire format could be detected. " + 291 " 8 bytes are required to detect the protocol but only: " + readBytes.get() + " byte(s) were sent."); 292 } 293 } 294 295 @Override 296 protected TcpTransport createTransport(Socket socket, WireFormat format) throws IOException { 297 return new TcpTransport(format, socket, this.initBuffer); 298 } 299 300 /** 301 * @param socket 302 * @param format 303 * @param detectedTransportFactory 304 * @return 305 */ 306 protected TcpTransport createTransport(Socket socket, WireFormat format, 307 TcpTransportFactory detectedTransportFactory) throws IOException { 308 return createTransport(socket, format); 309 } 310 311 public void setWireFormatOptions(Map<String, Map<String, Object>> wireFormatOptions) { 312 this.wireFormatOptions = wireFormatOptions; 313 } 314 315 public void setEnabledProtocols(Set<String> enabledProtocols) { 316 this.enabledProtocols = enabledProtocols; 317 } 318 319 public void setAutoTransportOptions(Map<String, Object> autoTransportOptions) { 320 this.autoTransportOptions = autoTransportOptions; 321 if (autoTransportOptions.get("protocols") != null) { 322 this.enabledProtocols = AutoTransportUtils.parseProtocols((String) autoTransportOptions.get("protocols")); 323 } 324 } 325 @Override 326 protected void doStop(ServiceStopper stopper) throws Exception { 327 if (service != null) { 328 service.shutdown(); 329 } 330 super.doStop(stopper); 331 } 332 333 protected ProtocolInfo detectProtocol(byte[] buffer) throws IOException { 334 TcpTransportFactory detectedTransportFactory = transportFactory; 335 WireFormatFactory detectedWireFormatFactory = wireFormatFactory; 336 337 boolean found = false; 338 for (String scheme : protocolVerifiers.keySet()) { 339 if (protocolVerifiers.get(scheme).isProtocol(buffer)) { 340 LOG.debug("Detected protocol " + scheme); 341 detectedWireFormatFactory = findWireFormatFactory(scheme, wireFormatOptions); 342 343 if (scheme.equals("default")) { 344 scheme = ""; 345 } 346 347 detectedTransportFactory = (TcpTransportFactory) findTransportFactory(scheme, transportOptions); 348 found = true; 349 break; 350 } 351 } 352 353 if (!found) { 354 throw new IllegalStateException("Could not detect the wire format"); 355 } 356 357 return new ProtocolInfo(detectedTransportFactory, detectedWireFormatFactory); 358 359 } 360 361 protected class ProtocolInfo { 362 public final TcpTransportFactory detectedTransportFactory; 363 public final WireFormatFactory detectedWireFormatFactory; 364 365 public ProtocolInfo(TcpTransportFactory detectedTransportFactory, 366 WireFormatFactory detectedWireFormatFactory) { 367 super(); 368 this.detectedTransportFactory = detectedTransportFactory; 369 this.detectedWireFormatFactory = detectedWireFormatFactory; 370 } 371 } 372 373}