Spark version: 2.0.0
1. Overview
Spark is a distributed system, so it relies on a large amount of network communication and remote procedure calls (RPC). Before 1.6 this was implemented on top of Akka, but because of Akka compatibility problems it was eventually dropped in favour of Netty. This article describes how the Netty-based RPC service works.
The previous article covered the Master startup process but said very little about the rpcEnv part, so starting from the point where that article created the rpcEnv, this one explains how Spark services communicate with each other.
2. RPC implementation
2.1 RPC server-side implementation
During Master startup there is the line val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr),
which creates the RpcEnv. Let's now dig into what this code actually does.
RpcEnv.scala
------------------
def create(
    name: String,
    host: String,
    port: Int,
    conf: SparkConf,
    securityManager: SecurityManager,
    // clientMode = false: the side that starts the service is always the server
    clientMode: Boolean = false): RpcEnv = {
  // Wrap the RpcEnv settings into a config object
  val config = RpcEnvConfig(conf, name, host, port, securityManager, clientMode)
  // Use the Netty-based RpcEnv factory (factory pattern); for easier extension the
  // factory could also be instantiated via reflection
  new NettyRpcEnvFactory().create(config)
}
The key call above is new NettyRpcEnvFactory().create(config).
NettyRpcEnv.scala
--------------------
// Create the RpcEnv instance
def create(config: RpcEnvConfig): RpcEnv = {
  val sparkConf = config.conf
  // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support
  // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance
  // Serialization; a more flexible approach would be to create the serializer via reflection
  val javaSerializerInstance =
    new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
  // Create the NettyRpcEnv 【1】
  val nettyEnv =
    new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager)
  // On the server side the transport server must be started 【2】
  if (!config.clientMode) {
    // Start the server on the given port
    val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
      nettyEnv.startServer(actualPort)
      (nettyEnv, nettyEnv.address.port)
    }
    try {
      // Start the service
      Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
    } catch {
      case NonFatal(e) =>
        nettyEnv.shutdown()
        throw e
    }
  }
  nettyEnv
}
This code is crucial, so it is split into two main parts:
【1】 create the NettyRpcEnv
【2】 start the server side
Let's walk through them in turn.
【1】
NettyRpcEnv.scala
-------------------------
When the NettyRpcEnv object is constructed, the following fields deserve attention:
// (1) Convert the SparkConf into a SparkTransportConf (transport configuration)
private[netty] val transportConf = SparkTransportConf.fromSparkConf(
  conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
  "rpc",
  conf.getInt("spark.rpc.io.threads", 0))
// (2) Dispatches messages to endpoints
private val dispatcher: Dispatcher = new Dispatcher(this)
// (3) Handles data streams
private val streamManager = new NettyStreamManager(this)
// (4) Transport context
private val transportContext = new TransportContext(transportConf,
  new NettyRpcHandler(dispatcher, this, streamManager))
(1) SparkTransportConf is the configuration object dedicated to the transport layer.
(2) As mentioned when covering Master startup, Dispatcher.registerRpcEndpoint registers an endpoint and records it in the endpoints and endpointRefs maps; at registration time the endpoint's inbox is also seeded with an OnStart message, which is what later triggers the call to endpoint.onStart (see the sketch after this list).
(3) NettyStreamManager handles streaming of files, jars, directories, etc.
(4) transportContext mainly holds the following fields:
private final TransportConf conf;      // transport configuration
private final RpcHandler rpcHandler;   // message handler, e.g. turning a byte stream into a RequestMessage
private final MessageEncoder encoder;  // message encoder
private final MessageDecoder decoder;  // message decoder
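To make the OnStart mechanism in (2) concrete, here is a deliberately simplified sketch. It is not Spark's actual Dispatcher code; names such as SimpleInbox, SimpleEndpoint and SimpleDispatcher are made up for illustration. It only shows the idea that registering an endpoint seeds its inbox with OnStart, and that a message loop later drains the inbox and calls back into the endpoint.
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}

// Simplified stand-ins for Spark's InboxMessage hierarchy (illustrative only)
sealed trait InboxMessage
case object OnStart extends InboxMessage
case class RpcMessage(content: Any) extends InboxMessage

trait SimpleEndpoint {
  def onStart(): Unit = {}
  def receive(msg: Any): Unit
}

// A minimal inbox: created with OnStart already enqueued, mirroring Spark's Inbox
class SimpleInbox(endpoint: SimpleEndpoint) {
  private val messages = new LinkedBlockingQueue[InboxMessage]()
  messages.put(OnStart) // the first thing every endpoint processes is OnStart -> onStart()

  def post(msg: InboxMessage): Unit = messages.put(msg)

  def process(): Unit = messages.take() match {
    case OnStart             => endpoint.onStart()
    case RpcMessage(content) => endpoint.receive(content)
  }
}

// A minimal registry, analogous to the Dispatcher's endpoints map
class SimpleDispatcher {
  private val endpoints = new ConcurrentHashMap[String, SimpleInbox]()
  def registerRpcEndpoint(name: String, endpoint: SimpleEndpoint): SimpleInbox = {
    val inbox = new SimpleInbox(endpoint)
    endpoints.put(name, inbox)
    inbox
  }
}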
【2】 If config.clientMode == false, nettyEnv.startServer(actualPort) is called to start the server side. (There is some extra handling around actualPort: if the requested port is already in use, Utils.startServiceOnPort retries on a new port; see the sketch below.)
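As an aside, the retry behaviour of Utils.startServiceOnPort can be approximated as follows. This is a simplified sketch, not the actual Spark implementation: the real version also special-cases port 0 and reads the retry count from spark.port.maxRetries, and PortRetrySketch is just a made-up wrapper name.
import java.net.BindException

object PortRetrySketch {
  // Simplified: try startPort, and on a BindException retry on the next port,
  // up to maxRetries attempts.
  def startServiceOnPort[T](startPort: Int,
                            startService: Int => (T, Int),
                            maxRetries: Int = 16): (T, Int) = {
    var offset = 0
    while (offset <= maxRetries) {
      val tryPort = startPort + offset
      try {
        return startService(tryPort) // startService returns (service, actualBoundPort)
      } catch {
        case _: BindException => offset += 1 // port already in use: try the next one
      }
    }
    throw new BindException(s"Could not bind after $maxRetries retries starting at $startPort")
  }
}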
NettyRpcEnv.scala
----------------------
/**
 * Start the server
 * @param port server port
 */
def startServer(port: Int): Unit = {
  val bootstraps: java.util.List[TransportServerBootstrap] =
    if (securityManager.isAuthenticationEnabled()) {
      java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager))
    } else {
      java.util.Collections.emptyList()
    }
  // Start the transport server
  server = transportContext.createServer(host, port, bootstraps)
  // Register the verifier endpoint; registration works the same way as for the Master
  // endpoint, so it is not analysed further here
  dispatcher.registerRpcEndpoint(
    RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
}
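For comparison, this is roughly how an endpoint of your own would be registered with an RpcEnv (the Master does exactly this). It is a hedged sketch: EchoEndpoint and the Echo message are made up for illustration, while setupEndpoint, ThreadSafeRpcEndpoint, receiveAndReply and RpcCallContext are the real Spark 2.0 APIs; note that org.apache.spark.rpc is private[spark], so this only compiles inside Spark's own source tree.
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}

// A hypothetical endpoint that just echoes whatever it receives
case class Echo(text: String)

class EchoEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
  override def onStart(): Unit = {
    // Triggered by the OnStart message that registerRpcEndpoint put into the inbox
  }
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case Echo(text) => context.reply(s"echo: $text") // the reply travels back through the RpcResponseCallback
  }
}

// Registration: Dispatcher.registerRpcEndpoint is called under the hood and an
// RpcEndpointRef pointing at the new endpoint is returned.
// val echoRef = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))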
Back in startServer, look first at transportContext.createServer(host, port, bootstraps).
It ends up calling new TransportServer(this, host, port, rpcHandler, bootstraps),
so let's look at how a TransportServer is instantiated:
public TransportServer(
    TransportContext context,
    String hostToBind,
    int portToBind,
    RpcHandler appRpcHandler,
    List<TransportServerBootstrap> bootstraps) {
  this.context = context;
  this.conf = context.getConf();
  this.appRpcHandler = appRpcHandler;
  this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps));
  try {
    // Initialize the Netty server
    init(hostToBind, portToBind);
  } catch (RuntimeException e) {
    JavaUtils.closeQuietly(this);
    throw e;
  }
}
Initializing the Netty server involves quite a few steps, but most of them are standard boilerplate for creating a Netty server, so they are not covered here. The one piece worth highlighting is this:
private void init(String hostToBind, int portToBind) {
  ......
  // Add the message handlers
  bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
    @Override
    protected void initChannel(SocketChannel ch) throws Exception {
      // From the call chain above we know appRpcHandler is the NettyRpcHandler instance ***
      RpcHandler rpcHandler = appRpcHandler;
      // Wrap the rpcHandler with additional layers, e.g. SASL authentication (decorator pattern)
      for (TransportServerBootstrap bootstrap : bootstraps) {
        rpcHandler = bootstrap.doBootstrap(ch, rpcHandler);
      }
      // Initialize the channel pipeline (the handler runs after the MessageDecoder)
      context.initializePipeline(ch, rpcHandler);
    }
  });
Next, look at context.initializePipeline,
the method that adds the handlers to channel.pipeline():
TransportContext.java
--------------------------
public TransportChannelHandler initializePipeline(
    SocketChannel channel,
    RpcHandler channelRpcHandler) {
  try {
    TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
    channel.pipeline()
      .addLast("encoder", encoder)
      // Handles TCP framing (sticky / partial packets)
      .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
      // Message decoding
      .addLast("decoder", decoder)
      // When a connection has been idle (no reads or writes) for too long, an IdleStateEvent is
      // fired; it can be handled by overriding userEventTriggered in a ChannelInboundHandler,
      // which is why TransportChannelHandler overrides userEventTriggered.
      .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
      // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
      // would require more logic to guarantee if this were not part of the same event loop.
      .addLast("handler", channelHandler);
    return channelHandler;
  } catch (RuntimeException e) {
    logger.error("Error while initializing Netty pipeline", e);
    throw e;
  }
}
The resulting pipeline (omitting the IdleStateHandler) is, in order: outbound messages pass through the MessageEncoder, while inbound bytes are first framed by the TransportFrameDecoder, then decoded into Message objects by the MessageDecoder, and finally handled by the TransportChannelHandler.
To aid understanding, three of the handlers deserve a closer look: MessageEncoder, MessageDecoder and TransportChannelHandler.
(1) MessageEncoder: as its encode method shows, the encoder takes a Message object's type, header and body and writes them out as bytes.
(2) MessageDecoder: the exact opposite of MessageEncoder, it turns the incoming bytes back into a Message object (see the frame-layout sketch after the decode snippet below):
private Message decode(Message.Type msgType, ByteBuf in) {
  switch (msgType) {
    case ChunkFetchRequest:
      return ChunkFetchRequest.decode(in);
    case ChunkFetchSuccess:
      return ChunkFetchSuccess.decode(in);
    ......
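For reference, the wire format that MessageEncoder produces and TransportFrameDecoder/MessageDecoder consume is, as far as I can tell from the 2.0 sources, roughly: an 8-byte frame length (which counts itself), a 1-byte message type tag, the message-specific header fields, and then the body. The field widths are my own reading and not guaranteed for every message type. A tiny illustrative Scala sketch of that framing idea, with a made-up wrapper name:
import java.nio.ByteBuffer

object FrameLayoutSketch {
  // Illustrative only: the general [frame length][type tag][header][body] shape
  def encodeFrame(typeTag: Byte, header: Array[Byte], body: Array[Byte]): ByteBuffer = {
    val frameLength = 8L + 1 + header.length + body.length // the length field counts itself
    val buf = ByteBuffer.allocate(frameLength.toInt)
    buf.putLong(frameLength) // TransportFrameDecoder reads this to split the byte stream into frames
    buf.put(typeTag)         // MessageDecoder switches on this tag (ChunkFetchRequest, RpcRequest, ...)
    buf.put(header)          // message-specific fields, e.g. requestId and body length for an RpcRequest
    buf.put(body)            // the serialized payload
    buf.flip()
    buf
  }
}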
(3) TransportChannelHandler:
@Override
public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
  // Distinguish the message type
  if (request instanceof RequestMessage) {
    requestHandler.handle((RequestMessage) request);   // see TransportRequestHandler.handle
  } else {
    responseHandler.handle((ResponseMessage) request); // see TransportResponseHandler.handle
  }
}
TransportChannelHandler distinguishes three kinds of messages: RPC messages, ChunkFetch messages and Stream messages (a sketch of the request-side dispatch follows this list).
- RPC messages cover everything Spark needs to transfer for RPC operations; they are usually small, control-plane messages.
- ChunkFetch messages cover data-pull operations and are used to transfer shuffle data and RDD block data.
- Stream messages are simpler and are mainly used to transfer jars, files, etc. from the driver to the executors.
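On the server side, TransportRequestHandler.handle picks a processXxx branch per concrete request type. The sketch below only mirrors that shape: the real code is Java and lives in org.apache.spark.network.protocol / TransportRequestHandler, and the case-class stand-ins here (with made-up fields) are for illustration only; ChunkFetchRequest, RpcRequest, OneWayMessage and StreamRequest are, to my knowledge, the request types in 2.0.
// Illustrative stand-ins for Spark's request message hierarchy
sealed trait RequestMessage
case class ChunkFetchRequest(streamChunkId: Long) extends RequestMessage
case class RpcRequest(requestId: Long, body: Array[Byte]) extends RequestMessage
case class OneWayMessage(body: Array[Byte]) extends RequestMessage
case class StreamRequest(streamId: String) extends RequestMessage

object RequestDispatchSketch {
  // The shape of TransportRequestHandler.handle: one branch per request type
  def handle(request: RequestMessage): Unit = request match {
    case r: ChunkFetchRequest => println(s"fetch chunk ${r.streamChunkId}") // shuffle / RDD block data
    case r: RpcRequest        => println(s"rpc request ${r.requestId}")     // the path analysed below
    case _: OneWayMessage     => println("one-way rpc, no response expected")
    case r: StreamRequest     => println(s"open stream ${r.streamId}")      // jars / files from the driver
  }
}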
Here we focus on how RPC messages are handled.
(1) Handling RequestMessage
TransportRequestHandler.java
------------------------
/**
 * Handle an RPC request
 * @param req
 */
private void processRpcRequest(final RpcRequest req) {
  try {
    // rpcHandler is the NettyRpcHandler
    // Core logic: rpcHandler.receive
    rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() {
      @Override
      public void onSuccess(ByteBuffer response) {
        respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
      }
      @Override
      public void onFailure(Throwable e) {
        respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
      }
    });
  } catch (Exception e) {
    logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
    respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
  } finally {
    req.body().release();
  }
}
NettyRpcHandler.scala
---------------------------
// Receive and handle the request
override def receive(
    client: TransportClient,
    message: ByteBuffer,
    callback: RpcResponseCallback): Unit = {
  // ByteBuffer => RequestMessage
  val messageToDispatch = internalReceive(client, message)
  // Dispatch the message; as described earlier it ends up in receivers, which triggers
  // the endpoint's message handling (see the sketch below)
  dispatcher.postRemoteMessage(messageToDispatch, callback)
}
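Tying this back to the dispatcher from part 【1】: postRemoteMessage wraps the RpcResponseCallback into a call context, posts the message into the target endpoint's inbox, and the dispatcher's message-loop threads later drain that inbox and invoke the endpoint's receiveAndReply. The sketch below only illustrates that flow; PostedRpc and SimpleMessageLoop are invented names, not Spark's actual classes.
import java.util.concurrent.LinkedBlockingQueue

// A posted RPC carries the payload plus a way to answer (or fail) the original caller
case class PostedRpc(content: Any, reply: Any => Unit, fail: Throwable => Unit)

class SimpleMessageLoop(inbox: LinkedBlockingQueue[PostedRpc],
                        handle: PartialFunction[Any, Any]) extends Runnable {
  override def run(): Unit = {
    while (true) {
      val msg = inbox.take() // blocks until something like NettyRpcHandler.receive posts a message
      try {
        // The "endpoint.receiveAndReply" step: compute an answer and hand it to the callback,
        // which on the server side becomes an RpcResponse for the original requestId
        msg.reply(handle(msg.content))
      } catch {
        case e: Throwable => msg.fail(e) // becomes an RpcFailure on the wire
      }
    }
  }
}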
(2) Handling ResponseMessage
TransportResponseHandler.java
------------------------
public void handle(ResponseMessage message) throws Exception {
  ....
  else if (message instanceof RpcResponse) {
    // Handle RpcResponse
    RpcResponse resp = (RpcResponse) message;
    RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
    if (listener == null) {
      logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
        resp.requestId, remoteAddress, resp.body().size());
    } else {
      outstandingRpcs.remove(resp.requestId);
      try {
        // Report success through the listener
        listener.onSuccess(resp.body().nioByteBuffer());
      } finally {
        resp.body().release();
      }
    }
  } else if (message instanceof RpcFailure) {
    // Handle RpcFailure
    RpcFailure resp = (RpcFailure) message;
    RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
    if (listener == null) {
      logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
        resp.requestId, remoteAddress, resp.errorString);
    } else {
      outstandingRpcs.remove(resp.requestId);
      // Report the failure through the listener
      listener.onFailure(new RuntimeException(resp.errorString));
    }
  }
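The outstandingRpcs map used above is populated on the sending side: when TransportClient.sendRpc is called it generates a requestId, registers the callback under that id, and only then writes the RpcRequest to the channel, so the response handler can correlate the reply by requestId. The following is a simplified illustration of that correlation, not the real TransportClient (which, for instance, derives the requestId differently); SimpleRpcClient and ResponseCallback are made-up names.
import java.nio.ByteBuffer
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong

// Illustrative callback type, mirroring RpcResponseCallback
trait ResponseCallback {
  def onSuccess(response: ByteBuffer): Unit
  def onFailure(e: Throwable): Unit
}

class SimpleRpcClient(writeToChannel: (Long, ByteBuffer) => Unit) {
  private val nextId = new AtomicLong(0)
  private val outstandingRpcs = new ConcurrentHashMap[Long, ResponseCallback]()

  // Send side: register the callback first, then write the request
  def sendRpc(message: ByteBuffer, callback: ResponseCallback): Long = {
    val requestId = nextId.incrementAndGet()
    outstandingRpcs.put(requestId, callback)
    writeToChannel(requestId, message)
    requestId
  }

  // Receive side: look the callback up by requestId, exactly like handle(ResponseMessage) above
  def onRpcResponse(requestId: Long, response: ByteBuffer): Unit = {
    val callback = outstandingRpcs.remove(requestId)
    if (callback != null) callback.onSuccess(response)
  }
}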
2.2 RPC client-side implementation
In the earlier analysis of Master startup we looked at the line val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest) (send a request to the Master's RPC endpoint to obtain the bound ports).
Its job is to use the reference to the Master endpoint to call the Master's service, i.e. to request data as a Master client (a usage sketch follows). As we saw last time, there are two cases: if remoteAddr == address (same process), the RequestMessage is simply put into the local inbox, which is the easy case. So what happens in the second case, when the endpoint lives in a remote server?
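For context, this is what the client-side API looks like from the caller's point of view. It is a hedged usage sketch: it assumes a masterEndpoint: RpcEndpointRef obtained as in the previous article (BoundPortsRequest / BoundPortsResponse are the real Master messages), and since org.apache.spark.rpc is private[spark] it only compiles inside Spark's own code base.
import scala.concurrent.Future

// Blocking, retried variant used during Master startup (deprecated in later Spark versions)
val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest)

// Non-blocking variant: the Future is completed by the RpcResponse / RpcFailure path shown above
val portsFuture: Future[BoundPortsResponse] = masterEndpoint.ask[BoundPortsResponse](BoundPortsRequest)

// Fire-and-forget variant: masterEndpoint.send(msg) becomes a OneWayMessage on the wire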
The first step is to create an RpcOutboxMessage and add it to an Outbox; if a connection to the remote server already exists locally, the message is sent on it right away.
NettyRpcEnv.scala
---------------------
// Wrap the request into an outbox message
val rpcMessage = RpcOutboxMessage(serialize(message),
  onFailure,
  (client, response) => onSuccess(deserialize[Any](client, response)))
// Hand it to the outbox of the target address
postToOutbox(message.receiver, rpcMessage)

/**
 * Post a message to the outbox
 * @param receiver
 * @param message
 */
private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
  if (receiver.client != null) {
    // If we already have a connection to the receiver, send directly.
    // On the first call receiver.client is null.
    message.sendWith(receiver.client)
  } else {
    require(receiver.address != null,
      "Cannot send message to client endpoint with no listen address.")
    // One remote address maps to one Outbox
    val targetOutbox = {
      // Check whether an outbox for this address already exists
      val outbox = outboxes.get(receiver.address)
      if (outbox == null) {
        // No outbox yet: create one
        val newOutbox = new Outbox(this, receiver.address)
        val oldOutbox = outboxes.putIfAbsent(receiver.address, newOutbox)
        if (oldOutbox == null) {
          newOutbox
        } else {
          oldOutbox
        }
      } else {
        outbox
      }
    }
    if (stopped.get) {
      // It's possible that we put `targetOutbox` after stopping. So we need to clean it.
      // Remove it from the outbox map and stop it
      outboxes.remove(receiver.address)
      targetOutbox.stop()
    } else {
      // Send the message to the receiver
      targetOutbox.send(message)
    }
  }
}
targetOutbox.send is straightforward: it just checks whether the outbox has already been stopped before queueing the message.
Outbox.scala
--------------------------
def send(message: OutboxMessage): Unit = {
  val dropped = synchronized {
    if (stopped) {
      true
    } else {
      // Append the message to the outbox's message queue
      messages.add(message)
      false
    }
  }
  if (dropped) {
    message.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
  } else {
    // Process the queued outbox messages
    drainOutbox()
  }
}
The real message handling starts in drainOutbox, so let's look at how it works: if no client exists yet, one has to be created first (the important part); if a client already exists, all currently queued messages are sent to the remote server.
Outbox.scala
----------------------
private def drainOutbox(): Unit = {
  var message: OutboxMessage = null
  synchronized {
    if (stopped) {
      return
    }
    if (connectFuture != null) {
      // We are connecting to the remote address, so just exit
      return
    }
    if (client == null) {
      // There is no connect task but client is null, so we need to launch the connect task.
      launchConnectTask()
      return
    }
    if (draining) {
      // There is some thread draining, so just exit
      return
    }
    message = messages.poll()
    if (message == null) {
      return
    }
    // Mark that this thread is now draining
    draining = true
  }
  // Keep consuming until the outbox queue is empty
  while (true) {
    try {
      val _client = synchronized { client }
      if (_client != null) {
        // Send to the receiver
        message.sendWith(_client)
      } else {
        assert(stopped == true)
      }
    } catch {
      case NonFatal(e) =>
        handleNetworkFailure(e)
        return
    }
    synchronized {
      if (stopped) {
        return
      }
      // Fetch the next message
      message = messages.poll()
      if (message == null) {
        // Queue drained: clear the flag and return
        draining = false
        return
      }
    }
  }
}
Since we are investigating the RPC client, launchConnectTask is the method to focus on.
Outbox.scala
----------------------
private def launchConnectTask(): Unit = {
  connectFuture = nettyEnv.clientConnectionExecutor.submit(new Callable[Unit] {
    override def call(): Unit = {
      try {
        // Create a client connected to the remote address
        val _client = nettyEnv.createClient(address)
        outbox.synchronized {
          client = _client
          if (stopped) {
            closeClient()
          }
        }
      } catch {
        case ie: InterruptedException =>
          // exit
          return
        case NonFatal(e) =>
          outbox.synchronized { connectFuture = null }
          handleNetworkFailure(e)
          return
      }
      outbox.synchronized { connectFuture = null }
      // It's possible that no thread is draining now. If we don't drain here, we cannot send the
      // messages until the next message arrives.
      // Once the connection is established, drain the queue again
      drainOutbox()
    }
  })
}
The key call here is nettyEnv.createClient(address).
What follows is a bit more involved, so bear with me: nettyEnv.createClient eventually calls TransportClientFactory.createClient. That method first checks whether a connection to the remote server is already cached; if so it is returned directly, otherwise a new one is created via createClient(resolvedAddress). The creation logic is very similar to the server side, so the inline comments should be enough.
TransportClientFactory.java
-----------------------------
public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
  // Wrap host and port into an InetSocketAddress
  final InetSocketAddress unresolvedAddress =
    InetSocketAddress.createUnresolved(remoteHost, remotePort);
  // Check whether the connection pool already holds connections to this remote server
  ClientPool clientPool = connectionPool.get(unresolvedAddress);
  if (clientPool == null) {
    // Create the connection pool
    connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer));
    clientPool = connectionPool.get(unresolvedAddress);
  }
  // Pick a random connection from the cached pool
  int clientIndex = rand.nextInt(numConnectionsPerPeer);
  TransportClient cachedClient = clientPool.clients[clientIndex];
  // If the cached connection is still alive
  if (cachedClient != null && cachedClient.isActive()) {
    // Fetch the TransportChannelHandler from its pipeline
    TransportChannelHandler handler = cachedClient.getChannel().pipeline()
      .get(TransportChannelHandler.class);
    synchronized (handler) {
      // Update the last-request timestamp
      handler.getResponseHandler().updateTimeOfLastRequest();
    }
    if (cachedClient.isActive()) {
      logger.trace("Returning cached connection to {}: {}",
        cachedClient.getSocketAddress(), cachedClient);
      return cachedClient;
    }
  }
  // No usable cached connection to this remote server, so create a new one
  final long preResolveHost = System.nanoTime();
  final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort);
  final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000;
  // Warn if DNS resolution took too long (the threshold is hard-coded here; it could be made configurable)
  if (hostResolveTimeMs > 2000) {
    logger.warn("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs);
  } else {
    logger.trace("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs);
  }
  // Update the client at the chosen pool slot
  synchronized (clientPool.locks[clientIndex]) {
    cachedClient = clientPool.clients[clientIndex];
    if (cachedClient != null) {
      if (cachedClient.isActive()) {
        logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient);
        return cachedClient;
      } else {
        logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress);
      }
    }
    // Create the remote connection and store it in the pool slot
    clientPool.clients[clientIndex] = createClient(resolvedAddress);
    return clientPool.clients[clientIndex];
  }
}
/**
 * The code that actually creates the RPC client
 */
private TransportClient createClient(InetSocketAddress address) throws IOException {
  logger.debug("Creating new connection to " + address);
  Bootstrap bootstrap = new Bootstrap();
  bootstrap.group(workerGroup)
    .channel(socketChannelClass)
    .option(ChannelOption.TCP_NODELAY, true)
    .option(ChannelOption.SO_KEEPALIVE, true)
    .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
    .option(ChannelOption.ALLOCATOR, pooledAllocator);
  final AtomicReference<TransportClient> clientRef = new AtomicReference<>();
  final AtomicReference<Channel> channelRef = new AtomicReference<>();
  // Install the handlers; this ends up calling the same pipeline setup as on the server side
  bootstrap.handler(new ChannelInitializer<SocketChannel>() {
    @Override
    public void initChannel(SocketChannel ch) {
      TransportChannelHandler clientHandler = context.initializePipeline(ch);
      // Capture the client and channel so they can be used outside the anonymous class
      clientRef.set(clientHandler.getClient());
      channelRef.set(ch);
    }
  });
  long preConnect = System.nanoTime();
  ChannelFuture cf = bootstrap.connect(address);
  if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
    throw new IOException(
      String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
  } else if (cf.cause() != null) {
    throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
  }
  TransportClient client = clientRef.get();
  Channel channel = channelRef.get();
  assert client != null : "Channel future completed successfully with null client";
  long preBootstrap = System.nanoTime();
  logger.debug("Connection to {} successful, running bootstraps...", address);
  // Decorator pattern: run any additional client bootstraps (e.g. SASL)
  try {
    for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
      clientBootstrap.doBootstrap(client, channel);
    }
  } catch (Exception e) {
    long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
    logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e);
    client.close();
    throw Throwables.propagate(e);
  }
  long postBootstrap = System.nanoTime();
  logger.info("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
    address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);
  return client;
}
Reference: https://www.cnblogs.com/xia520pi/p/8693966.html