背景
[作者:DeepLearningStack,阿里巴巴算法工程师,开源 TensorFlow Contributor]
本篇是 TensorFlow 通信机制系列的第二篇文章,主要梳理使用 gRPC 网络传输部分模块的结构和源码。如果读者对 TensorFlow 中 Rendezvous 部分的基本结构和原理还不是非常了解,那么建议先从这篇文章开始阅读。TensorFlow 在最初被开源时还只是个单机的异构训练框架,在迭代到 0.8 版本开始正式支持多机分布式训练。与其他分布式训练框架不同,Google 选用了开源项目 gRPC 作为 TensorFlow 的跨机通信协议作为支持。gRPC 的编程和使用其实是相对复杂的,TensorFlow 为了能让 gRPC 的调用更加平滑,在调用链封装和抽象上面做了较多工作,甚至有些工作例如创建和管理 gRPC channel 涉及到了 GrpcSession 模块。从个人角度来看,利用 gRPC 进行 Tensor 通信的过程已经足够丰富,所以我们只针对 gRPC 传输 Tensor 过程进行梳理,至于涉及到 gRPC 管理方面的内容会在另一篇介绍分布式 Session 创建和管理的文章中集中梳理。
跨进程通信过程
根据之前写博客的经验,直接介绍类图结构和源码部分可能会让人懵圈,还是先从逻辑上把通信过程梳理清楚更能做到深入浅出。其实对于不是非常了解分布式系统或大规模并发系统的读者而言,TensorFlow 中通信过程是有些 “别扭” 的。那么有的读者可能会觉得诧异,跨进程通信过程不就是一方做 Send,另一方做 Recv 吗?这是一个理所当然的过程,为什么会 “别扭” 呢?是的,整个过程依然是一方做 Send,另一方做 Recv。而它的 “别扭” 之处就在于 —— 真正的通信过程由 Recv 方触发,而不是 Send 方!这就是理解 TensorFlow 中使用 gRPC 传输 Tensor 过程的最关键点。
前一篇文章分析过在本地传输的场景下 Tensor 通信的大体过程,从机制和逻辑上来说,跨进程传输过程和本地传输没有很大的差异:TensorFlow 使用 Rendezvous 通信 Tensor,借助一个类似 Table 的数据结构作为传输的中转,并且 Send 方和 Recv 方依靠 ParsedKey 这一唯一传输标识符,跨进程通信也是如此。如果读者对这部分内容不了解,可以参考这篇文章。
Send 方 —— 将 Ready 的 Tensor 挂入本地 Table
和本地传输场景下的 Send 过程相同,本地 Tensor 处于 Ready 状态后就被放挂了本地 Worker 的 Table 中,至此 Send 过程就全部完成了。所以 Send 过程完全没有涉及到任何跨网络传输的内容,并且 Send 过程是非阻塞的。
Recv 方 —— 向 Send 方主动发出请求,触发通信过程
Recv 方是 Tensor 的接收方,它的处理过程是:将所需要的 Tensor 对应的 ParsedKey 拼出后,主动向 Send 方主动发出 Request,Send 方在接收到 Request 后立即在本地 Table 中查找方所需要的 Tensor,找到后将 Tensor 封装成 Response 发送回 Recv 方。在这个过程中,Recv 方可以认为是 Client,Send 方可以认为是 Server,通过发送 Request 和 Response 来完成 Tensor 的传输。
结构设计解析
建议读者在阅读本节时适当翻开 TensorFlow C++ 部分源码,但只需要理解结构关系即可(比如类之间的继承、组合、依赖关系),暂时不要阅读类的实现内容。因为 RemoteRendezvous 部分涉及到的类结构非常多,直接陷入细节的阅读会深陷其中不能自拔,甚至弄得一头雾水十分疲惫。在梳理结构时一边参照下文中的类图结构,一边从设计模式和架构的角度尝试去理解每个模块的司职是理解本篇细节的关键。先理解宏观结构看懂架子,再去深入理解实现细节尝试去优化是读任何代码的正确顺序。
任何场景下,通信过程几乎都是可以通过简单的图将功能描述清楚的。但是不可否认的是,任何涉及到分布式通信的系统在架构上都会对通信层做相对复杂的封装。一方面是因为通信虽然功能简单,但其实现本身具有相对较高的复杂性(大家可以尝试阅读 gRPC 源码感受下底层软件的复杂度)。另一方面,应用层也需要与通信底层通过抽象尽量实现较好的解耦,这样也方便将应用层模块被其他团队扩展编写。下面我们一起来探究 TensorFlow 中涉及到跨进程通信的 Rendezvous 系列。
两层抽象继承关系 ——RemoteRendezvous 与 BaseRemoteRendezvous
前一篇在介绍本地传输时我们熟悉了 Rendezvous 模块中与本地传输相关的类,例如 LocalRendezvousImpl,IntraProcessRendezvous 和 SimpleRendezvous。对应地,跨进程传输也有不同的 Rendezvous,从根源上来说,它们也继承于 Rendezvous 接口,并且不同的传输协议也有各自的 Rendezvous。在这里,我们再次将前文中展示的总体类结构图展示出来,这次我们将涉及到远程传输的类用特殊颜色标出,如下图所示。
综合来看,从 Rendezvous 的继承结构来看,涉及到跨进程传输的 Rendezvous 有层:
1. RemoteRendezvous:只增加了一个 Initialize 方法,并标记为纯虚函数。这是因为跨进程 Rendezvous 需要借助 Session 做一些初始化工作,所以 TensorFlow 中所有涉及到跨进程通信的 Rendezvous 都需要重写 Initialize 函数,使用前也必须强制调用该函数。
2. 各种具体协议 Rendezvous 的基类 ——BaseRemoteRendezvous:既然所有涉及跨进程通信的 Rendezvous 都需要提供各自协议下实现的 Initialize 函数,那么没有比在 RemoteRendezvous 和真正特化的 Rendezvous 之间再添加一层继承关系更合适的做法了。事实上 TensorFlow 在此处也是这么设计的,这个承上启下的类就是 BaseRemoteRendezvous。它还提供了公共的 Send 和 Recv 方法,这可以让继承它的特化 Rendezvous 尽最大可能做到代码复用。
BaseRecvTensorCall 是通信的实体抽象,后面分析时会有更深的体会,在这里先有个印象即可。
开始特化 —— 各种各样的 RemoteRendezvous
TensorFlow 目标是通用可扩展,所以被设计成允许底层支持多种通信协议的结构。事实上到目前为止,算上 contrib 目录的内容(contrib 目录是广大 TensorFlow 贡献者添加的内容),TensorFlow 已经支持包括 gRPC,RDMA(Remote Direct Memroy Access),GDR(GPU Dirrect)和 MPI 四种通信协议,因此包含了四种对应的 Rendezvous,他们分别是 RpcRemoteRendezvous,RDMARemoteRendezvous,GdrRemoteRendezvous 和 MPIRemoteRendezvous。每种通信协议各有其特点,有时候其可用性也取决于硬件和软件条件(比如 RDMA 需要支持 RDMA 协议的网卡,通常跑在 Infiniband 和 RoCE 网络上,如果没有硬件支持,那么 RDMA 将无法使用,GDR 也是这个道理)。从代码中可以看出,实现每种具体的 RemoteRendezvous 都有一定的复杂性,所以很难想象在没有封装抽象和代码复用的结构里如何实现这些内容。在本篇我们关注 RpcRemoteRendezvous,它是 gRPC 协议实现的 RemoteRendezvous。
令人熟悉的管理器模式 ——RendezvousMgr
为了更好地管理 RemoteRendezvous,TensorFlow 设计了相应的管理器 ——RendezvousMgr 相关类,并为每种具体的 RemoteRendevzous 做了特化。熟悉设计模式的读者都知道,管理器是一种经典的设计模式,它能使管理职责的变化独立于类本身。RendezvousMgr 主要负责 RemoteRendezvous 的创建和销毁,它也定义了两个本地版本的 Recv 接口。有的读者可能会问,管理器为什么还允许做 Recv?并且只能做本地的 Recv?我个人判断添加这两个接口纯粹是为了方便某些地方的使用。至于 RendezvousMgr 的创建时机和 RemoteRendezvous 的初始化过程并不是本篇解析的范畴,因为这涉及到分布式场景下创建 Server 的较长链路,这部分内容会在以后的博客中详细解析。下面是 RendezvousMgr 相关的类图结构,我们可以看到其接口类中已经定义了 Recv 接口。
RpcRemoteRendezvous 通信过程与源码解析
上一小节中对 RemoteRendezvous 相关类结构和类间的关系做了解析,旨在从架构层面帮助读者理解各个类的职能。虽然涉及到的内容比较多,但是整体的结构和逻辑还是非常清晰的。如果读者尝试通过阅读源码辅助理解上述内容之后仍然感觉有些眼花缭乱,没有关系,我们在这里暂时做一个简单地梳理,将重点内容梳理到以下几条。
1. 本地 Rendezvous 和 RemoteRendezvous 共同继承了同一个接口;
2. RemoteRendezvous 需要支持不同的通信协议,因此派生了各种各样的实现类;
3. RemoteRendezvous 的使用较为复杂,为此引入了管理器模式 ——RendezvousMgr,它负责 RemoteRendezvous 的创建和销毁,并添加了两个额外的 Recv 接口方便某些场景直接调用;
4. RemoteRendezvous 做了两层继承结构只是为了添加一个 Initialize 方法。
本篇我们梳理使用 gRPC 协议的部分,从上文中梳理的结构中不难看出,这部分涉及到的类并不多。
1. Rendezvous 相关类 ——RemoteRendezvous,BaseRemoteRendezvous,RpcRemoteRendezvous;
2. 管理器 ——BaseRendezvousMgr,RpcRendezvousMgr
3. 其他类 ——BaseRecvTensorCall,RpcRecvTensorCall 和 DefferedCall
毕竟是涉及到了 gRPC 协议本身的使用,所以有必要在梳理源码之前从宏观上对 gRPC 的工作流程做一个简单地梳理。
gRPC 编程中的代理模式 ——Stub 与 Service
在此我们假设同学们对 gRPC 的原理和使用有一些基本的了解,比如需要使用 Protobuf 预先定义 Service 接口,并且区分 Stub 和 Service 等。对此不了解的同学还是建议先认真阅读一下 gRPC 的使用文档和范例,下面这段文字只对 gRPC 做一个非常简单的描述。
在一次 RPC 调用中,客户端需要调用服务端的服务,然后将处理结果返回给客户端。而 gRPC 做到了 “让客户端调用远端函数时就像调用本地函数一样” 的体验,这得益于一种经典的设计模式 —— 代理模式。负责为客户端代理的节点(gRPC 中称之为 Stub)会将请求和参数传到服务端,并由 Service 进行实际的处理,然后将结果返回给 Stub,最终返回到客户端中。我们甚至可以认为负责代理的 Stub 就是客户端,因为它的职责就是与远端交互并取得结果。另外,为了能够让传输量尽可能少,也为了能够让传输不受客户端和服务端具体的类型限制,gRPC 在做跨网络传输前将消息统一序列化成 Protobuf 格式。下图是从 gRPC 官网教程中摘出的工作原理图。
Send 过程
因为 Send 过程并不涉及跨进程传输,只是将 Ready 的 Tensor 挂入本地 Table 之中,所以它和 LocalRendezvousImpl 的 Send 完全相同。不仅如此,TensorFlow 中的任何 RemoteRendezvous 的 Send 过程都要遵循这样的原理,基于代码复用的考虑,将这部分内容都被抽象到了公共基类 BaseRemoteRendezvous 的 Send 函数里是一个很好的设计。事实上,BaseRemoteRendezvous 的 Send 过程就是调用了 LocalRendezvousImpl 的 Send 过程,所以 LocalRendezvousImpl 必须要作为 BaseRemoteRendezvous 的成员之一。下面的代码展示了这一过程。
1 Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
2 const Rendezvous::Args& args,
3 const Tensor& val, const bool is_dead) {
4 VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
5 {
6 mutex_lock l(mu_);
7 if (!status_.ok()) return status_;
8 DCHECK(is_initialized_locked());
9 if (!IsLocalDevice(session_->worker_name, parsed.src_device)) {
10 return errors::InvalidArgument(
11 "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
12 session_->worker_name);
13 }
14 }
15 // Buffers "val" and "device_context" in local_.
16 return local_->Send(parsed, args, val, is_dead);
17 }
Recv 过程
Recv 过程就非常复杂了,因为每种 RemoteRendezvous 都涉及到不同的通信协议以及管理方式,所以 Recv 函数是真正需要继承重写的模块。在看 RpcRemoteRendezvous 具体的实现之前,我们必须先将 gRPC 定义服务的接口部分梳理清楚。
gRPC 的服务定义接口文件
在 TensorFlow 的 core/protobuf 文件中,我们需要研究一下 worker_service.proto 文件,这个文件中定义了若干 RPC Service 接口。
虽然它定义了很多 RPC 服务接口,但是我们只需要关注和 Tensor 接收相关的接口定义即可。准确地说,目前我们必须要知道的是下面这个 Service 定义。
// See worker.proto for details.
rpc RecvTensor(RecvTensorRequest) returns (RecvTensorResponse) {
// RecvTensor Method
}
显然,这是一个让服务端处理 “接收 Tensor” 的服务(注意是让服务端处理名为 “接收 Tensor” 的服务,而不是让服务端去接收 Tensor。因为客户端有接收 Tensor 的需求,但需要服务端发送 Tensor,为客户端发送 Tensor 的服务被称之为 “接收 Tensor”),按照注释提示,我们可以在 worker.proto 中找到 RecvTensorRequest 和 RecvTensorResponse 的数据结构,这部分结构读者可以自己查阅,非常容易理解。在编译时,扩展的 Protobuf 编译器会对 worker_service.proto 中的 rpc 接口生成 C++ 服务接口代码和 Stub 代码(毕竟 Stub 代码比较纯粹并且和业务逻辑无关,它只是一个向对应 Service 端发送处理请求的过程),TensorFlow 只需要对具体的 Service 提供实现即可。
与 gRPC 生成的代码联系起来
gRPC 会为 worker_service.proto 中每一个 rpc 服务生成 C++ 接口代码,为了区分多个 rpc 服务,特意为每个服务生成了特殊的名字。比如 RecvTensor 服务的名字就是 /tensorflow.WorkerService/RecvTensor。为了不直接使用冗长的字符串,TensorFlow 为 worker_service.proto 中的每个服务都做了 enumeration 的映射,这部分代码在 tensorflow/core/distributed_runtime/grpc_worker_service_impl.h 和同名实现文件中。
1 // Names of worker methods.
2 enum class GrpcWorkerMethod {
3 kGetStatus,
4 kCreateWorkerSession,
5 kDeleteWorkerSession,
6 kRegisterGraph,
7 kDeregisterGraph,
8 kRunGraph,
9 kCleanupGraph,
10 kCleanupAll,
11 kRecvTensor,
12 kRecvBuf,
13 kLogging,
14 kTracing,
15 kCompleteGroup,
16 kCompleteInstance,
17 kGetStepSequence,
18 };
下面是从 enumeration 类型映射到具体字符串的函数。
1 const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
2 switch (id) {
3 case GrpcWorkerMethod::kGetStatus:
4 return "/tensorflow.WorkerService/GetStatus";
5 case GrpcWorkerMethod::kCreateWorkerSession:
6 return "/tensorflow.WorkerService/CreateWorkerSession";
7 case GrpcWorkerMethod::kDeleteWorkerSession:
8 return "/tensorflow.WorkerService/DeleteWorkerSession";
9 case GrpcWorkerMethod::kRegisterGraph:
10 return "/tensorflow.WorkerService/RegisterGraph";
11 case GrpcWorkerMethod::kDeregisterGraph:
12 return "/tensorflow.WorkerService/DeregisterGraph";
13 case GrpcWorkerMethod::kRunGraph:
14 return "/tensorflow.WorkerService/RunGraph";
15 case GrpcWorkerMethod::kCleanupGraph:
16 return "/tensorflow.WorkerService/CleanupGraph";
17 case GrpcWorkerMethod::kCleanupAll:
18 return "/tensorflow.WorkerService/CleanupAll";
19 case GrpcWorkerMethod::kRecvTensor:
20 return "/tensorflow.WorkerService/RecvTensor";
21 case GrpcWorkerMethod::kRecvBuf:
22 return "/tensorflow.WorkerService/RecvBuf";
23 case GrpcWorkerMethod::kLogging:
24 return "/tensorflow.WorkerService/Logging";
25 case GrpcWorkerMethod::kTracing:
26 return "/tensorflow.WorkerService/Tracing";
27 case GrpcWorkerMethod::kCompleteGroup:
28 return "/tensorflow.WorkerService/CompleteGroup";
29 case GrpcWorkerMethod::kCompleteInstance:
30 return "/tensorflow.WorkerService/CompleteInstance";
31 case GrpcWorkerMethod::kGetStepSequence:
32 return "/tensorflow.WorkerService/GetStepSequence";
33 }
34 // Shouldn't be reached.
35 LOG(FATAL) << "Invalid id: this line shouldn't be reached.";
36 return "invalid id";
37 }
另外,还需要为每个 RPC 服务注册为异步服务,这需要使用 gRPC 自带的 AddMethod 接口和 MarkMethodAsync 接口,如下所示。
1 WorkerService::AsyncService::AsyncService() {
2 for (int i = 0; i < kGrpcNumWorkerMethods; ++i) {
3 AddMethod(new ::grpc::internal::RpcServiceMethod(
4 GrpcWorkerMethodName(static_cast<GrpcWorkerMethod>(i)),
5 ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
6 ::grpc::Service::MarkMethodAsync(i);
7 }
8 }
好了,接下来就是解析源码中具体的交互过程了。其实 TensorFlow 在框架层面对 gRPC 的使用了一些 Best Practice,比如异步处理请求的架构和多线程轮询 Completion Queue 等。将这些连在一起梳理需要更多的篇幅,一次性展示大量的内容也不利于阅读,所以我们只对发送和接收过程做一个梳理。
Client 端的调用链
从 BaseRemoteRendeezvous 的 RecvAsync 出发,逐渐深入调用链底层。时序图是分析调用链的最好工具,下面给出了 Client 端到 Stub 的调用过程,这里面涉及到了几个新的类。
1. RpcRecvTensorCall:这是一次 gRPC 调用的抽象,继承了 BaseRecvTensorCall 这个抽象基类,它封装了复杂的后续调用链。
2. GrpcRemoteWorker:它也是 client 端的内容,只不过它是 Remote 端的代理。
3. RpcState:这是真正封装了一次 RPC 调用及状态的类,它会直接对 Stub 以及 GenericClientAsyncResponseReader 进行管理,比如向服务端发送异步请求并等待结果等。
Client 端是一个虚拟角色,它可以是调用 RpcRemoteRendezvous 的任何一个模块。我们可以看到,RpcRemoteRendezvous 的一次 RecvRemoteAsync 过程非常长,并且 Stub 的调用时异步的。这里的代码确实有些多,所以我们只展示一下关键代码段,但是建议读者打开源码仔细阅读每个调用链。
下面是 RecvRemoteAsync 的代码段,主要做了 RpcRecvTensorCall 的初始化,注册以及启动工作。
1 void RpcRemoteRendezvous::RecvFromRemoteAsync(
2 const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
3 DoneCallback done) {
4 CHECK(is_initialized());
5 Status s;
6
7 // Prepare a RecvTensor call that can handle being aborted.
8 RpcRecvTensorCall* call = get_call_freelist()->New();
9
10 // key.src_device identifies a remote device.
11 if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_,
12 &call->src_rel_device_)) {
13 s = errors::Internal(parsed.src_device,
14 " is invalid remote source device.");
15 }
16 WorkerSession* sess = session();
17 WorkerInterface* rwi = sess->worker_cache->CreateWorker(call->src_worker_);
18 if (s.ok() && rwi == nullptr) {
19 s = errors::Internal("No worker known as ", call->src_worker_);
20 }
21
22 Device* dst_device;
23 if (s.ok()) {
24 s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
25 }
26 if (!s.ok()) {
27 if (rwi != nullptr) {
28 sess->worker_cache->ReleaseWorker(call->src_worker_, rwi);
29 }
30 get_call_freelist()->Release(call, sess->worker_cache.get());
31 done(s, Args(), recv_args, Tensor{}, false);
32 return;
33 }
34
35 call->Init(rwi, step_id_, parsed.FullKey(), recv_args.alloc_attrs, dst_device,
36 recv_args, std::move(done));
37
38 // Record "call" in active_ so that it can be aborted cleanly.
39 RegisterCall(call);
40
41 // RendezvousMgr already aborted, shouldn't send RPC call any more
42 if (!call->status().ok()) {
43 call->done()(call->status(), Args(), Args(), Tensor(), false);
44 session()->worker_cache->ReleaseWorker(call->src_worker_, call->wi_);
45 call->wi_ = nullptr;
46 get_call_freelist()->Release(call, session()->worker_cache.get());
47 return;
48 }
49
50 // Start "call".
51 Ref();
52 call->Start([this, call]() {
53 // Removes "call" from active_. Prevent StartAbort().
54 DeregisterCall(call);
55 // If StartAbort was called prior to DeregisterCall, then the
56 // current status should be bad.
57 Status s = call->status();
58 call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
59 session()->worker_cache->ReleaseWorker(call->src_worker_, call->wi_);
60 call->wi_ = nullptr;
61 get_call_freelist()->Release(call, session()->worker_cache.get());
62 Unref();
63 });
64 }
下面是 GrpcRemoteWorker 调用 RPCState 的过程,最后的 IssueRequest 即开始创建 RPCState 并触发 stub 的调用。
void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request,
TensorResponse* response, StatusCallback done) override {
VLOG(1) << "RecvTensorAsync req: " << request->DebugString();
int64 start_usec = Env::Default()->NowMicros();
// Type-specialized logging for this method.
bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
StatusCallback wrapper_done;
const StatusCallback* cb_to_use;
if (!logging_active) {
cb_to_use = &done; // No additional work to do, so just use done directly
} else {
wrapper_done = [this, request, response, done, start_usec](Status s) {
if (logger_->LoggingActive()) {
int64 end_usec = Env::Default()->NowMicros();
int64 step_id = request->step_id();
int64 bytes = response->tensor().TotalBytes();
int64 send_start_usec = start_usec;
// If a send start time was reported by the other side, use
// that instead. Maybe we should mark the display if we're using
// our local time instead of the remote start time?
if (response->metadata().send_start_micros()) {
// send_start_micros is the timestamp taken when the
// remote machine began to send the RecvTensor response.
// Due to clock skew between source and dest machines, it
// is possible that send_start_micros can be larger than
// end_usec or less than start_usec.
//
// To respect causality, we enforce the invariants that
// the RecvTensor response can not have been sent before
// the RecvTensor request, and must have been sent before
// it was received.
send_start_usec = std::max(
start_usec,
static_cast<int64>(response->metadata().send_start_micros()));
send_start_usec = std::min(send_start_usec, end_usec - 1);
}
const string& key = request->rendezvous_key();
std::vector<string> key_parts = str_util::Split(key, ';');
if (key_parts.size() != 5) {
LOG(WARNING) << "Bad key: " << key;
} else {
logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
key_parts[3], // tensor name
key_parts[0], // src_device
key_parts[2], // dst_device
bytes);
}
}
VLOG(2) << "done callback, req: " << request->DebugString()
<< " response " << response->metadata().DebugString();
done(s);
};
cb_to_use = &wrapper_done;
}
IssueRequest(request, response, recvtensor_, *cb_to_use, call_opts);
}
最后展示一下 Stub 的触发位置,这个函数在 RPCState 类中,并且在创建 RPCState 对象时立即被调用。
1 void StartCall() {
2 context_.reset(new ::grpc::ClientContext());
3 context_->set_fail_fast(fail_fast_);
4
5 if (timeout_in_ms_ > 0) {
6 context_->set_deadline(
7 gpr_time_from_millis(timeout_in_ms_, GPR_TIMESPAN));
8 }
9 if (call_opts_) {
10 call_opts_->SetCancelCallback([this]() { context_->TryCancel(); });
11 }
12
13 VLOG(2) << "Starting call: " << method_;
14
15 call_ = std::move(
16 stub_->PrepareUnaryCall(context_.get(), method_, request_buf_, cq_));
17 call_->StartCall();
18 call_->Finish(&response_buf_, &status_, this);
19 }
Server 端负责查找 Tensor 的 Service
如果我们把异步处理请求的架构和多线程轮询 Completion Queue 的 Best Practice 去除,那么 Service 端其实并不复杂,调用链相对 Client 端短了很多,下面的时序图展示了自 Server 端接收请求后的调用过程,这里面也涉及到了几个新的类。
1. GrpcWorkerServiceThread:这是服务端处理请求的线程类。
2. GrpcWorker:这是真正负责处理请求的 Worker,是 GrpcRemoteWorker 的服务端版本;
3. WorkerCall:这是服务端处理一次 gRPC 请求和响应的类,抽象为 WorkerCall,其实这也是个别名,真实的名称较长;
4. ServerAsyncResponseWriter:这是 gRPC 为用户端提供的 Response writer,是承载响应的实体。
5. Utils:这其实不是一个类,而是多个工具的组合,为了在时序图表达方便,统称为 Utils。
可以看出,服务端接收到请求后,会调用 RecvLocalAsync 在本地将客户端所需要的 Tensor 查找出来,然后拷贝到 CPU 上,最后利用 gRPC 发送回客户端。同样,我们展示关键代码段。
下面是 GrpcWorker 调用 RendezvousMgr 的 RecvLocalAsync 为客户端寻找真正 Tensor 的过程。回调函数中能够看出,在找到对应 Tensor 后,需要将 Tensor 做 Encode,然后拷贝到 CPU 端。
1 env_->rendezvous_mgr->RecvLocalAsync(
2 step_id, parsed,
3 [opts, response, done, src_dev, request](
4 const Status& status, const Rendezvous::Args& send_args,
5 const Rendezvous::Args& recv_args, const Tensor& val,
6 const bool is_dead) {
7 opts->ClearCancelCallback();
8 if (status.ok()) {
9 // DMA can only be used for Tensors that do not fall into
10 // the following three odd edge cases: 1) a zero-size
11 // buffer, 2) a dead tensor which has an uninit value, and
12 // 3) the tensor has the on_host allocation attribute,
13 // i.e. it's in CPU RAM *independent of its assigned
14 // device type*.
15 const bool on_host = send_args.alloc_attrs.on_host();
16 {
17 // Non-DMA cases.
18 if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
19 DeviceContext* send_dev_context = send_args.device_context;
20 AllocatorAttributes alloc_attrs;
21 alloc_attrs.set_gpu_compatible(true);
22 alloc_attrs.set_on_host(true);
23 Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
24 Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
25 CHECK(send_dev_context)
26 << "send dev name: " << src_dev->name()
27 << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
28 // "val" is on an accelerator device. Uses the device_context to
29 // fill the copy on host.
30 StatusCallback copy_ready = [response, done, copy,
31 is_dead](const Status& s) {
32 // The value is now ready to be returned on the wire.
33 grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
34 done(s);
35 delete copy;
36 };
37
38 send_dev_context->CopyDeviceTensorToCPU(
39 &val, request->rendezvous_key(), src_dev, copy, copy_ready);
40 } else {
41 grpc::EncodeTensorToByteBuffer(is_dead, val, response);
42 done(Status::OK());
43 }
44 }
45 } else {
46 // !s.ok()
47 done(status);
48 }
49 });
至此,我们的 Rendezvous 之 gRPC 传输之旅就圆满结束了,在阅读本篇时还是希望读者能够在理解结构设计后,对照 C++ 源码仔细阅读反复推敲里面的每一个细节,这样才能有更深的理解。
一个需要思考的问题 ——gRPC 传输 Tensor 很低效?
是的,确实很低效。为什么?从设计哲学上说,gRPC 本身设计并不适合深度学习训练场景。从细节上来说它有以下几个缺陷:
1. gRPC 发送 Tensor 前,接收 Tensor 后必须要做序列化,在 Tensor 很大的时候这是一个非常讨厌的 overhead,发送接收延迟过大;
2. 序列化根本没有对数据做任何压缩,这是因为 Tensor 都是稠密的,所以序列化没有意义;
3. 不能支持 RDMA 和 GPU Direct。虽然这依赖于硬件,但是 gRPC 在软件层面也并没有做这些适配。
所以大部分人使用 TensorFlow 分布式时都会对性能有很大的抱怨,这里面很大的原因和 gRPC 有关。如果你使用 NCCL 或者 MPI,那么你会得到不一样的性能。
总结
本篇文章篇幅较长,是 Rendezvous 机制系列的第二篇,主要梳理了涉及到 gRPC 传输的模块架构设计和源码细节,并且详细梳理了通信过程。理解 TensorFlow 跨机传输的关键在于理解一个事实:真正的通信过程由 Recv 方触发,而不是 Send 方!Send 依然将 Ready 的 Tensor 挂入本地 Table 中,而 Recv 会向 Send 端发送 gRPC 请求查询所需要的 Tensor,然后返回所需要的结果,这个过程虽然有些别扭,但逻辑上并不稀奇。从结构设计上来说,RemoteRendezvous 沿用了 Rendezvous 接口,并且完全复用了 LocalRendezvousImpl 的 Send 代码,而 Recv 由于涉及到具体的通信细节和管理机制,则各有各的不同。另外,RemoteRendezvous 相对 LocalRendezvous 复杂很多,需要管理器进行管理。最后一大部分是 Send 和 Recv 的源码细节展示,因为无论是客户端还是服务端,其调用链都比较长,所以以时序图的形式展示各个类之间的调用关系和协作关系较为清晰,具体每个调用的细节建议读者结合源码逐一分析,并连同本篇文章一起理解较为深刻。最后,我们总结了 gRPC 传输 Tensor 的明显缺陷,当然这也是为性能优化开辟了新的空间。