alvinalexander.com | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Akka/Scala example source code file (NettyTransport.scala)

This example Akka source code file (NettyTransport.scala) is included in my "Source Code Warehouse" project. The intent of this project is to help you more easily find Akka and Scala source code examples by using tags.

All credit for the original source code belongs to akka.io; I'm just trying to make examples easier to find. (For my Scala work, see my Scala examples and tutorials.)

Akka tags/keywords

address, akka, channel, concurrent, future, int, option, remote, some, string, time, transport, transportmode, udp, unit, utilities

The NettyTransport.scala Akka example source code

/**
 * Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
 */
package akka.remote.transport.netty

import akka.actor.{ Address, ExtendedActorSystem }
import akka.dispatch.ThreadPoolConfig
import akka.event.Logging
import akka.remote.transport.AssociationHandle.HandleEventListener
import akka.remote.transport.Transport._
import akka.remote.transport.netty.NettyTransportSettings.{ Udp, Tcp, Mode }
import akka.remote.transport.{ AssociationHandle, Transport }
import akka.{ OnlyCauseStackTrace, ConfigurationException }
import com.typesafe.config.Config
import java.net.{ UnknownHostException, SocketAddress, InetAddress, InetSocketAddress, ConnectException }
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{ ConcurrentHashMap, Executors, CancellationException }
import org.jboss.netty.bootstrap.{ ConnectionlessBootstrap, Bootstrap, ClientBootstrap, ServerBootstrap }
import org.jboss.netty.buffer.{ ChannelBuffers, ChannelBuffer }
import org.jboss.netty.channel._
import org.jboss.netty.channel.group.{ DefaultChannelGroup, ChannelGroup, ChannelGroupFuture, ChannelGroupFutureListener }
import org.jboss.netty.channel.socket.nio.{ NioWorkerPool, NioDatagramChannelFactory, NioServerSocketChannelFactory, NioClientSocketChannelFactory }
import org.jboss.netty.handler.codec.frame.{ LengthFieldBasedFrameDecoder, LengthFieldPrepender }
import org.jboss.netty.handler.ssl.SslHandler
import scala.concurrent.duration.{ Duration, FiniteDuration, MILLISECONDS }
import scala.concurrent.{ ExecutionContext, Promise, Future, blocking }
import scala.util.{ Failure, Success, Try }
import scala.util.control.{ NoStackTrace, NonFatal }
import akka.util.Helpers.Requiring
import akka.util.Helpers
import akka.remote.RARP
import org.jboss.netty.util.HashedWheelTimer

/**
 * Transport protocol choices understood by the `transport-protocol` setting.
 */
object NettyTransportSettings {

  /** The wire protocol a Netty transport runs over. */
  sealed trait Mode

  /** Stream transport; renders as "tcp" (also used to build the address scheme). */
  case object Tcp extends Mode {
    override def toString: String = "tcp"
  }

  /** Datagram transport; renders as "udp" (also used to build the address scheme). */
  case object Udp extends Mode {
    override def toString: String = "udp"
  }
}

/**
 * Adapters that turn Netty 3 listener-based futures into Scala [[scala.concurrent.Future]]s.
 */
object NettyFutureBridge {

  /**
   * Bridges a single Netty channel operation.
   *
   * The returned future succeeds with the operation's channel, fails with a `CancellationException`
   * if the Netty future was cancelled, and otherwise fails with the cause Netty reported.
   */
  def apply(nettyFuture: ChannelFuture): Future[Channel] = {
    val promise = Promise[Channel]()
    val onDone = new ChannelFutureListener {
      def operationComplete(completed: ChannelFuture): Unit = {
        val outcome = Try {
          if (completed.isSuccess) completed.getChannel
          else if (completed.isCancelled) throw new CancellationException
          else throw completed.getCause
        }
        promise complete outcome
      }
    }
    nettyFuture.addListener(onDone)
    promise.future
  }

  /**
   * Bridges a Netty channel-group operation.
   *
   * Succeeds with the group only if every member operation succeeded. Otherwise it fails with the
   * first problem found among the member futures: a `CancellationException` for a cancelled member, or
   * that member's cause. If no individual failure can be found, it fails with an `IllegalStateException`.
   */
  def apply(nettyFuture: ChannelGroupFuture): Future[ChannelGroup] = {
    import scala.collection.JavaConverters._
    val promise = Promise[ChannelGroup]()
    val onDone = new ChannelGroupFutureListener {
      def operationComplete(completed: ChannelGroupFuture): Unit = {
        val outcome = Try {
          if (completed.isCompleteSuccess) completed.getGroup
          else {
            val firstProblem = completed.iterator.asScala.collectFirst {
              case member if member.isCancelled ⇒ new CancellationException
              case member if !member.isSuccess  ⇒ member.getCause
            }
            throw firstProblem.getOrElse(
              new IllegalStateException("Error reported in ChannelGroupFuture, but no error found in individual futures."))
          }
        }
        promise complete outcome
      }
    }
    nettyFuture.addListener(onDone)
    promise.future
  }
}

/**
 * Failure raised by the Netty-based remoting transport.
 *
 * @param msg   human-readable description of the failure
 * @param cause underlying exception, or `null` when there is none
 */
@SerialVersionUID(1L)
class NettyTransportException(msg: String, cause: Throwable)
  extends RuntimeException(msg, cause)
  with OnlyCauseStackTrace {

  /** Creates the exception with a message only (no underlying cause). */
  def this(msg: String) = this(msg, null: Throwable)
}

/**
 * Typed, eagerly-validated view of the Netty transport configuration.
 *
 * Every setting is read while the instance is constructed, so an invalid value fails construction
 * rather than surfacing later when the transport is started.
 *
 * @param config configuration holding the transport keys read below (`transport-protocol`, `enable-ssl`,
 *               `hostname`, `port`, buffer sizes, `server-socket-worker-pool`, `client-socket-worker-pool`, ...)
 * @throws ConfigurationException if `transport-protocol` is unknown, or a byte-size setting is negative
 *                                or does not fit in an `Int`
 */
class NettyTransportSettings(config: Config) {

  import akka.util.Helpers.ConfigOps
  import config._

  /** Wire protocol selected by `transport-protocol` ("tcp" or "udp"). */
  val TransportMode: Mode = getString("transport-protocol") match {
    case "tcp"   ⇒ Tcp
    case "udp"   ⇒ Udp
    case unknown ⇒ throw new ConfigurationException(s"Unknown transport: [$unknown]")
  }

  /** Whether SSL is enabled; SSL is only accepted together with TCP. */
  val EnableSsl: Boolean = getBoolean("enable-ssl") requiring (!_ || TransportMode == Tcp, s"$TransportMode does not support SSL")

  /** Dispatcher name from `use-dispatcher-for-io`, or `None` when the setting is empty. */
  val UseDispatcherForIo: Option[String] = getString("use-dispatcher-for-io") match {
    case "" | null  ⇒ None
    case dispatcher ⇒ Some(dispatcher)
  }

  /**
   * Reads an optional byte-size setting, where 0 means "not set".
   *
   * The raw value is checked as a `Long` before narrowing: a plain `.toInt` would silently wrap sizes
   * above `Int.MaxValue` into an arbitrary (possibly positive or zero) `Int`.
   */
  private[this] def optionSize(s: String): Option[Int] = {
    val bytes: Long = getBytes(s)
    if (bytes == 0L) None
    else if (bytes < 0L || bytes > Int.MaxValue)
      throw new ConfigurationException(s"Setting '$s' must be 0 or positive (and fit in an Int)")
    else Some(bytes.toInt)
  }

  val ConnectionTimeout: FiniteDuration = config.getMillisDuration("connection-timeout")

  val WriteBufferHighWaterMark: Option[Int] = optionSize("write-buffer-high-water-mark")

  val WriteBufferLowWaterMark: Option[Int] = optionSize("write-buffer-low-water-mark")

  val SendBufferSize: Option[Int] = optionSize("send-buffer-size")

  /** Socket receive buffer size; mandatory for UDP (it also sizes the fixed datagram receive buffer). */
  val ReceiveBufferSize: Option[Int] = optionSize("receive-buffer-size") requiring (s ⇒
    s.isDefined || TransportMode != Udp, "receive-buffer-size must be specified for UDP")

  /** Largest accepted frame/payload in bytes; must be at least 32000 and fit in an `Int`. */
  val MaxFrameSize: Int = {
    // Validate the raw Long first so oversized values are rejected instead of silently wrapping.
    val bytes: Long = getBytes("maximum-frame-size")
    if (bytes > Int.MaxValue)
      throw new ConfigurationException("Setting 'maximum-frame-size' must fit in an Int")
    bytes.toInt requiring (
      _ >= 32000,
      s"Setting 'maximum-frame-size' must be at least 32000 bytes")
  }

  val Backlog: Int = getInt("backlog")

  val TcpNodelay: Boolean = getBoolean("tcp-nodelay")

  val TcpKeepalive: Boolean = getBoolean("tcp-keepalive")

  /** SO_REUSEADDR; the special value "off-for-windows" enables it everywhere except on Windows. */
  val TcpReuseAddr: Boolean = getString("tcp-reuse-addr") match {
    case "off-for-windows" ⇒ !Helpers.isWindows
    case _                 ⇒ getBoolean("tcp-reuse-addr")
  }

  /** Host to bind/advertise; an empty setting falls back to the local host's IP address. */
  val Hostname: String = getString("hostname") match {
    case ""    ⇒ InetAddress.getLocalHost.getHostAddress
    case value ⇒ value
  }

  @deprecated("WARNING: This should only be used by professionals.", "2.0")
  val PortSelector: Int = getInt("port")

  /** SSL settings from the `security` section; only read when SSL is enabled. */
  val SslSettings: Option[SSLSettings] = if (EnableSsl) Some(new SSLSettings(config.getConfig("security"))) else None

  val ServerSocketWorkerPoolSize: Int = computeWPS(config.getConfig("server-socket-worker-pool"))

  val ClientSocketWorkerPoolSize: Int = computeWPS(config.getConfig("client-socket-worker-pool"))

  /** Worker pool size scaled from `pool-size-min`, `pool-size-factor` and `pool-size-max`. */
  private def computeWPS(config: Config): Int =
    ThreadPoolConfig.scaledPoolSize(
      config.getInt("pool-size-min"),
      config.getDouble("pool-size-factor"),
      config.getInt("pool-size-max"))

}

/**
 * INTERNAL API
 *
 * Behaviour shared by the server and client pipeline handlers: it tracks opened channels in the
 * transport's channel group and wires a new channel up to an [[AssociationHandle]].
 */
private[netty] trait CommonHandlers extends NettyHelpers {
  protected val transport: NettyTransport

  // Every opened channel joins the transport's channel group so shutdown can act on all of them at once.
  final override def onOpen(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = transport.channelGroup.add(e.getChannel)

  /** Creates the protocol-specific association handle for `channel`. */
  protected def createHandle(channel: Channel, localAddress: Address, remoteAddress: Address): AssociationHandle

  /** Attaches `listener` so it receives inbound data for `channel`; `msg` is passed through from `init`. */
  protected def registerListener(channel: Channel,
                                 listener: HandleEventListener,
                                 msg: ChannelBuffer,
                                 remoteSocketAddress: InetSocketAddress): Unit

  /**
   * Creates a handle for `channel` and hands it to `op`.
   *
   * The local address is built from the channel's local socket address, using the configured hostname.
   * If that socket address is not an `InetSocketAddress`, no handle is created and the channel is closed
   * via `NettyTransport.gracefulClose`.
   *
   * Once the handle's read-handler promise succeeds with a [[HandleEventListener]], the listener is
   * registered and the channel is made readable. `op` itself is invoked immediately, before that happens.
   */
  final protected def init(channel: Channel, remoteSocketAddress: SocketAddress, remoteAddress: Address, msg: ChannelBuffer)(
    op: (AssociationHandle ⇒ Any)): Unit = {
    // Brings the transport's implicit executionContext, schemeIdentifier, system and settings into scope.
    import transport._
    NettyTransport.addressFromSocketAddress(channel.getLocalAddress, schemeIdentifier, system.name, Some(settings.Hostname)) match {
      case Some(localAddress) ⇒
        val handle = createHandle(channel, localAddress, remoteAddress)
        handle.readHandlerPromise.future.onSuccess {
          case listener: HandleEventListener ⇒
            // NOTE(review): unchecked cast inside an async callback; a non-Inet remote address would fail here
            // with a ClassCastException reported to the execution context, not to any caller — confirm callers
            // always pass an InetSocketAddress.
            registerListener(channel, listener, msg, remoteSocketAddress.asInstanceOf[InetSocketAddress])
            // Reads stay blocked until a listener exists to receive them.
            channel.setReadable(true)
        }
        op(handle)

      case _ ⇒ NettyTransport.gracefulClose(channel)
    }
  }
}

/**
 * INTERNAL API
 *
 * Base for inbound (server-side) pipeline handlers. Each accepted channel is reported to the transport's
 * [[AssociationEventListener]] as an [[InboundAssociation]] once that listener becomes available.
 *
 * @param transport                 the owning transport
 * @param associationListenerFuture completes with the listener that should be notified of inbound associations
 */
private[netty] abstract class ServerHandler(protected final val transport: NettyTransport,
                                            private final val associationListenerFuture: Future[AssociationEventListener])
  extends NettyServerHelpers with CommonHandlers {

  // Implicit execution context for the Future callbacks below.
  import transport.executionContext

  /**
   * Blocks reads on `channel` and, once the association listener is known, creates a handle for the
   * channel and notifies the listener with an [[InboundAssociation]].
   */
  final protected def initInbound(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit = {
    // No reads until a handler has been registered for this association.
    channel.setReadable(false)
    associationListenerFuture.onSuccess {
      case listener: AssociationEventListener ⇒
        // hostName = None: the remote address uses the peer's actual IP rather than a configured name.
        // NOTE(review): this throw happens inside the onSuccess callback, so it is only reported to the
        // execution context and the channel is left non-readable — confirm that is the intended outcome.
        val remoteAddress = NettyTransport.addressFromSocketAddress(remoteSocketAddress, transport.schemeIdentifier,
          transport.system.name, hostName = None).getOrElse(
            throw new NettyTransportException(s"Unknown inbound remote address type [${remoteSocketAddress.getClass.getName}]"))
        init(channel, remoteSocketAddress, remoteAddress, msg) { listener notify InboundAssociation(_) }
    }
  }

}

/**
 * INTERNAL API
 *
 * Base for outbound (client-side) pipeline handlers. It exposes the association handle for the connection
 * through [[statusFuture]] once the channel has been initialised.
 *
 * @param transport     the owning transport
 * @param remoteAddress Akka address of the peer this client connects to
 */
private[netty] abstract class ClientHandler(protected final val transport: NettyTransport, remoteAddress: Address)
  extends NettyClientHelpers with CommonHandlers {

  /** Completed with the association handle when [[initOutbound]] runs. */
  final protected val statusPromise = Promise[AssociationHandle]()

  /** Read-only view of the outbound association's handle. */
  def statusFuture: Future[AssociationHandle] = statusPromise.future

  /** Initialises `channel` as an outbound association and completes [[statusPromise]] with its handle. */
  final protected def initOutbound(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit =
    init(channel, remoteSocketAddress, remoteAddress, msg) { handle ⇒ statusPromise.success(handle) }

}

/**
 * INTERNAL API
 *
 * Constants and helpers shared by [[NettyTransport]] and its pipeline handlers.
 */
private[transport] object NettyTransport {
  // 4 bytes will be used to represent the frame length. Used by netty LengthFieldPrepender downstream handler.
  val FrameLengthFieldLength = 4

  /**
   * Closes `channel` in three steps: an empty write (to force a flush), then a disconnect, then a close.
   * A failure in either of the first two steps is ignored, so the close is always attempted.
   */
  def gracefulClose(channel: Channel)(implicit ec: ExecutionContext): Unit = {
    def always(c: ChannelFuture) = NettyFutureBridge(c) recover { case _ ⇒ c.getChannel }
    for {
      _ ← always { channel.write(ChannelBuffers.buffer(0)) } // Force flush by waiting on a final dummy write
      _ ← always { channel.disconnect() }
    } channel.close()
  }

  /** Counter used to give each transport's channel group a unique name. */
  val uniqueIdCounter = new AtomicInteger(0)

  /**
   * Converts a socket address into an Akka [[Address]].
   *
   * @param addr             socket address to convert; only `InetSocketAddress` is supported
   * @param schemeIdentifier protocol part of the resulting address
   * @param systemName       actor system name part of the resulting address
   * @param hostName         host to use instead of the one taken from `addr`, if given
   * @return the address, or `None` when `addr` is not an `InetSocketAddress`
   */
  def addressFromSocketAddress(addr: SocketAddress, schemeIdentifier: String, systemName: String,
                               hostName: Option[String]): Option[Address] = addr match {
    case sa: InetSocketAddress ⇒
      // An unresolved InetSocketAddress has no InetAddress (getAddress returns null); fall back to its
      // host name instead of failing with a NullPointerException. Only evaluated when hostName is None.
      def hostFromSocket: String = Option(sa.getAddress).map(_.getHostAddress).getOrElse(sa.getHostName)
      Some(Address(schemeIdentifier, systemName, hostName.getOrElse(hostFromSocket), sa.getPort))
    case _ ⇒ None
  }
}

// FIXME: Split into separate UDP and TCP classes
/**
 * Akka remoting [[Transport]] backed by Netty 3. Depending on the configured `TransportMode`, it runs
 * over TCP (optionally wrapped in SSL) or over UDP.
 *
 * @param settings parsed transport configuration
 * @param system   actor system owning this transport; supplies dispatchers, the thread factory and logging
 */
class NettyTransport(val settings: NettyTransportSettings, val system: ExtendedActorSystem) extends Transport {

  /** Builds the transport by parsing `conf` into [[NettyTransportSettings]]. */
  def this(system: ExtendedActorSystem, conf: Config) = this(new NettyTransportSettings(conf), system)

  import NettyTransport._
  import settings._

  /**
   * Execution context for the transport's future callbacks. It resolves, in order, to:
   * the dispatcher named by `use-dispatcher-for-io`; the remoting dispatcher from the provider's remote
   * settings (when non-empty); the system's default dispatcher.
   */
  implicit val executionContext: ExecutionContext =
    settings.UseDispatcherForIo.orElse(RARP(system).provider.remoteSettings.Dispatcher match {
      case ""             ⇒ None
      case dispatcherName ⇒ Some(dispatcherName)
    }).map(system.dispatchers.lookup).getOrElse(system.dispatcher)

  /** Address protocol: the transport mode ("tcp"/"udp"), prefixed with "ssl." when SSL is enabled. */
  override val schemeIdentifier: String = (if (EnableSsl) "ssl." else "") + TransportMode
  override def maximumPayloadBytes: Int = settings.MaxFrameSize

  private final val isDatagram = TransportMode == Udp

  // Both are assigned by `listen` once the server channel is bound; they are null until then.
  @volatile private var localAddress: Address = _
  @volatile private var serverChannel: Channel = _

  private val log = Logging(system, this.getClass)

  /**
   * INTERNAL API
   *
   * Outbound UDP associations register their read listener here, keyed by the remote socket address.
   */
  private[netty] final val udpConnectionTable = new ConcurrentHashMap[SocketAddress, HandleEventListener]()

  /** Executor for Netty threads: the `use-dispatcher-for-io` dispatcher, else a cached pool on the system thread factory. */
  private def createExecutorService() =
    UseDispatcherForIo.map(system.dispatchers.lookup) getOrElse Executors.newCachedThreadPool(system.threadFactory)

  /*
   * Be aware, that the close() method of DefaultChannelGroup is racy, because it uses an iterator over a ConcurrentHashMap.
   * In the old remoting this was handled by using a custom subclass, guarding the close() method with a write-lock.
   * The usage of this class is safe in the new remoting, as close() is called after unbind() is finished, and no
   * outbound connections are initiated in the shutdown phase.
   */
  val channelGroup = new DefaultChannelGroup("akka-netty-transport-driver-channelgroup-" +
    uniqueIdCounter.getAndIncrement)

  private val clientChannelFactory: ChannelFactory = TransportMode match {
    case Tcp ⇒
      val boss, worker = createExecutorService()
      // We need to create a HashedWheelTimer here since Netty creates one with a thread that
      // doesn't respect the akka.daemonic setting
      new NioClientSocketChannelFactory(boss, 1, new NioWorkerPool(worker, ClientSocketWorkerPoolSize),
        new HashedWheelTimer(system.threadFactory))
    case Udp ⇒
      // This does not create a HashedWheelTimer internally
      new NioDatagramChannelFactory(createExecutorService(), ClientSocketWorkerPoolSize)
  }

  private val serverChannelFactory: ChannelFactory = TransportMode match {
    case Tcp ⇒
      val boss, worker = createExecutorService()
      // This does not create a HashedWheelTimer internally
      new NioServerSocketChannelFactory(boss, worker, ServerSocketWorkerPoolSize)
    case Udp ⇒
      // This does not create a HashedWheelTimer internally
      new NioDatagramChannelFactory(createExecutorService(), ServerSocketWorkerPoolSize)
  }

  /** Base pipeline: for stream (TCP) mode, length-prefixed framing with a 4-byte header; empty for UDP. */
  private def newPipeline: DefaultChannelPipeline = {
    val pipeline = new DefaultChannelPipeline

    if (!isDatagram) {
      pipeline.addLast("FrameDecoder", new LengthFieldBasedFrameDecoder(
        maximumPayloadBytes,
        0,
        FrameLengthFieldLength,
        0,
        FrameLengthFieldLength, // Strip the header
        true))
      pipeline.addLast("FrameEncoder", new LengthFieldPrepender(FrameLengthFieldLength))
    }

    pipeline
  }

  // Completed by the caller of `listen` with the listener for inbound associations.
  private val associationListenerPromise: Promise[AssociationEventListener] = Promise()

  /** SSL handler for one side of a connection; it closes the channel on SSL exceptions. */
  private def sslHandler(isClient: Boolean): SslHandler = {
    val handler = NettySSLSupport(settings.SslSettings.get, log, isClient)
    handler.setCloseOnSSLException(true)
    handler
  }

  private val serverPipelineFactory: ChannelPipelineFactory = new ChannelPipelineFactory {
    override def getPipeline: ChannelPipeline = {
      val pipeline = newPipeline
      if (EnableSsl) pipeline.addFirst("SslHandler", sslHandler(isClient = false))
      val handler = if (isDatagram) new UdpServerHandler(NettyTransport.this, associationListenerPromise.future)
      else new TcpServerHandler(NettyTransport.this, associationListenerPromise.future)
      pipeline.addLast("ServerHandler", handler)
      pipeline
    }
  }

  private def clientPipelineFactory(remoteAddress: Address): ChannelPipelineFactory =
    new ChannelPipelineFactory {
      override def getPipeline: ChannelPipeline = {
        val pipeline = newPipeline
        if (EnableSsl) pipeline.addFirst("SslHandler", sslHandler(isClient = true))
        val handler = if (isDatagram) new UdpClientHandler(NettyTransport.this, remoteAddress)
        else new TcpClientHandler(NettyTransport.this, remoteAddress)
        pipeline.addLast("clienthandler", handler)
        pipeline
      }
    }

  /** Applies the pipeline factory and the socket options shared by inbound and outbound bootstraps. */
  private def setupBootstrap[B <: Bootstrap](bootstrap: B, pipelineFactory: ChannelPipelineFactory): B = {
    bootstrap.setPipelineFactory(pipelineFactory)
    bootstrap.setOption("backlog", settings.Backlog)
    bootstrap.setOption("tcpNoDelay", settings.TcpNodelay)
    bootstrap.setOption("child.keepAlive", settings.TcpKeepalive)
    bootstrap.setOption("reuseAddress", settings.TcpReuseAddr)
    // ReceiveBufferSize is guaranteed to be defined for UDP by NettyTransportSettings.
    if (isDatagram) bootstrap.setOption("receiveBufferSizePredictorFactory", new FixedReceiveBufferSizePredictorFactory(ReceiveBufferSize.get))
    settings.ReceiveBufferSize.foreach(sz ⇒ bootstrap.setOption("receiveBufferSize", sz))
    settings.SendBufferSize.foreach(sz ⇒ bootstrap.setOption("sendBufferSize", sz))
    settings.WriteBufferHighWaterMark.foreach(sz ⇒ bootstrap.setOption("writeBufferHighWaterMark", sz))
    settings.WriteBufferLowWaterMark.foreach(sz ⇒ bootstrap.setOption("writeBufferLowWaterMark", sz))
    bootstrap
  }

  private val inboundBootstrap: Bootstrap = settings.TransportMode match {
    case Tcp ⇒ setupBootstrap(new ServerBootstrap(serverChannelFactory), serverPipelineFactory)
    case Udp ⇒ setupBootstrap(new ConnectionlessBootstrap(serverChannelFactory), serverPipelineFactory)
  }

  /** A fresh client bootstrap for a single outbound connection to `remoteAddress`. */
  private def outboundBootstrap(remoteAddress: Address): ClientBootstrap = {
    val bootstrap = setupBootstrap(new ClientBootstrap(clientChannelFactory), clientPipelineFactory(remoteAddress))
    bootstrap.setOption("connectTimeoutMillis", settings.ConnectionTimeout.toMillis)
    bootstrap.setOption("tcpNoDelay", settings.TcpNodelay)
    bootstrap.setOption("keepAlive", settings.TcpKeepalive)
    settings.ReceiveBufferSize.foreach(sz ⇒ bootstrap.setOption("receiveBufferSize", sz))
    settings.SendBufferSize.foreach(sz ⇒ bootstrap.setOption("sendBufferSize", sz))
    settings.WriteBufferHighWaterMark.foreach(sz ⇒ bootstrap.setOption("writeBufferHighWaterMark", sz))
    settings.WriteBufferLowWaterMark.foreach(sz ⇒ bootstrap.setOption("writeBufferLowWaterMark", sz))
    bootstrap
  }

  override def isResponsibleFor(address: Address): Boolean = true //TODO: Add configurable subnet filtering

  /**
   * Resolves the host and port of `addr` to a socket address. The DNS lookup is wrapped in `blocking`.
   * Fails with `IllegalArgumentException` if `addr` lacks a host or a port.
   */
  // TODO: This should be factored out to an async (or thread-isolated) name lookup service #2960
  def addressToSocketAddress(addr: Address): Future[InetSocketAddress] = addr match {
    case Address(_, _, Some(host), Some(port)) ⇒ Future { blocking { new InetSocketAddress(InetAddress.getByName(host), port) } }
    case _                                     ⇒ Future.failed(new IllegalArgumentException(s"Address [$addr] does not contain host or port information."))
  }

  /**
   * Binds the server channel to the configured hostname and port.
   *
   * Reads on the server channel stay disabled until the returned promise is completed with an
   * [[AssociationEventListener]]. If binding fails, the transport is shut down and the failure is rethrown.
   *
   * @return the bound local address together with the promise for the inbound association listener
   */
  override def listen: Future[(Address, Promise[AssociationEventListener])] = {
    for {
      address ← addressToSocketAddress(Address("", "", settings.Hostname, settings.PortSelector))
    } yield {
      try {
        val newServerChannel = inboundBootstrap match {
          case b: ServerBootstrap         ⇒ b.bind(address)
          case b: ConnectionlessBootstrap ⇒ b.bind(address)
        }

        // Block reads until a handler actor is registered
        newServerChannel.setReadable(false)
        channelGroup.add(newServerChannel)

        serverChannel = newServerChannel

        addressFromSocketAddress(newServerChannel.getLocalAddress, schemeIdentifier, system.name, Some(settings.Hostname)) match {
          case Some(address) ⇒
            localAddress = address
            associationListenerPromise.future.onSuccess { case listener ⇒ newServerChannel.setReadable(true) }
            (address, associationListenerPromise)
          case None ⇒ throw new NettyTransportException(s"Unknown local address type [${newServerChannel.getLocalAddress.getClass.getName}]")
        }
      } catch {
        case NonFatal(e) ⇒ {
          // Include the cause so the reason for the bind failure is not lost.
          log.error(e, "failed to bind to {}, shutting down Netty transport", address)
          try { shutdown() } catch { case NonFatal(_) ⇒ } // ignore possible exception during shutdown
          throw e
        }
      }
    }
  }

  /**
   * Opens an outbound association to `remoteAddress`.
   *
   * Fails with [[NettyTransportException]] if the transport has not been bound yet (before `listen`
   * completes, or after the server channel was unbound). Unknown-host, security and connect errors
   * are surfaced as `InvalidAssociationException`; cancellation and other non-fatal errors are surfaced
   * as `NettyTransportException`.
   */
  override def associate(remoteAddress: Address): Future[AssociationHandle] = {
    // serverChannel is null until listen() has bound it; read the volatile once.
    val boundChannel = serverChannel
    if ((boundChannel eq null) || !boundChannel.isBound) Future.failed(new NettyTransportException("Transport is not bound"))
    else {
      val bootstrap: ClientBootstrap = outboundBootstrap(remoteAddress)

      (for {
        socketAddress ← addressToSocketAddress(remoteAddress)
        readyChannel ← NettyFutureBridge(bootstrap.connect(socketAddress)) map {
          channel ⇒
            if (EnableSsl)
              blocking {
                channel.getPipeline.get(classOf[SslHandler]).handshake().awaitUninterruptibly()
              }
            // TCP reads stay blocked until the association's read handler is registered.
            if (!isDatagram) channel.setReadable(false)
            channel
        }
        handle ← if (isDatagram)
          Future {
            readyChannel.getRemoteAddress match {
              case addr: InetSocketAddress ⇒
                val handle = new UdpAssociationHandle(localAddress, remoteAddress, readyChannel, NettyTransport.this)
                handle.readHandlerPromise.future.onSuccess {
                  case listener ⇒ udpConnectionTable.put(addr, listener)
                }
                handle
              case unknown ⇒ throw new NettyTransportException(s"Unknown outbound remote address type [${unknown.getClass.getName}]")
            }
          }
        else
          readyChannel.getPipeline.get(classOf[ClientHandler]).statusFuture
      } yield handle) recover {
        case c: CancellationException ⇒ throw new NettyTransportException("Connection was cancelled") with NoStackTrace
        case u @ (_: UnknownHostException | _: SecurityException | _: ConnectException) ⇒ throw new InvalidAssociationException(u.getMessage, u.getCause)
        case NonFatal(t) ⇒ throw new NettyTransportException(t.getMessage, t.getCause) with NoStackTrace
      }
    }
  }

  /**
   * Unbinds, flushes, disconnects and closes all channels in the channel group, then releases the
   * channel factories.
   *
   * @return `true` only if every channel-group step succeeded
   */
  override def shutdown(): Future[Boolean] = {
    def always(c: ChannelGroupFuture) = NettyFutureBridge(c).map(_ ⇒ true) recover { case _ ⇒ false }
    for {
      // Force flush by trying to write an empty buffer and wait for success
      unbindStatus ← always(channelGroup.unbind())
      lastWriteStatus ← always(channelGroup.write(ChannelBuffers.buffer(0)))
      disconnectStatus ← always(channelGroup.disconnect())
      closeStatus ← always(channelGroup.close())
    } yield {
      // Release the selectors, but don't try to kill the dispatcher
      if (UseDispatcherForIo.isDefined) {
        clientChannelFactory.shutdown()
        serverChannelFactory.shutdown()
      } else {
        clientChannelFactory.releaseExternalResources()
        serverChannelFactory.releaseExternalResources()
      }
      lastWriteStatus && unbindStatus && disconnectStatus && closeStatus
    }

  }

}

Other Akka source code examples

Here is a short list of links related to this Akka NettyTransport.scala source code file:

... this post is sponsored by my books ...

#1 New Release!

FP Best Seller

 

new blog posts

 

Copyright 1998-2024 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.