博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Tensorflow 多任务学习 概念介绍
阅读量:4262 次
发布时间:2019-05-26

本文共 3734 字,大约阅读时间需要 12 分钟。

建立多任务图

多任务的一个特点是单个tensor输入(X),多个输出(Y_1,Y_2...)。因此在定义占位符时要定义多个输出。同样也需要有多个损失函数用于分别计算每个任务的损失。具体代码如下:

#  GRAPH CODE# ============# 导入 Tensorflowimport Tensorflow as tf# ======================# 定义图# ======================# 定义占位符X = tf.placeholder("float", [10, 10], name="X")Y1 = tf.placeholder("float", [10, 20], name="Y1")Y2 = tf.placeholder("float", [10, 20], name="Y2")# 定义权重initial_shared_layer_weights = np.random.rand(10,20)initial_Y1_layer_weights = np.random.rand(20,20)initial_Y2_layer_weights = np.random.rand(20,20)shared_layer_weights = tf.Variable(initial_shared_layer_weights, name="share_W", dtype="float32")Y1_layer_weights = tf.Variable(initial_Y1_layer_weights, name="share_Y1", dtype="float32")Y2_layer_weights = tf.Variable(initial_Y2_layer_weights, name="share_Y2", dtype="float32")# 使用relu激活函数构建层shared_layer = tf.nn.relu(tf.matmul(X,shared_layer_weights))Y1_layer = tf.nn.relu(tf.matmul(shared_layer,Y1_layer_weights))Y2_layer = tf.nn.relu(tf.matmul(shared_layer,Y2_layer_weights))# 计算lossY1_Loss = tf.nn.l2_loss(Y1-Y1_layer)Y2_Loss = tf.nn.l2_loss(Y2-Y2_layer)

用图表示出来大概是这样的:

shared

Shared_layer的输出分别作为Y1、Y2的输入,并分别计算loss。


训练

有了网络的构建,接下来是训练。有两种方式:

  1. 交替训练
  2. 联合训练

下面分别讲一下这两种方式。

交替训练

这次先放图,更容易理解: 

Alternate
选择训练需要在每个loss后面接一个优化器,这样就意味着每一次的优化只针对于当前任务,也就是说另一个任务是完全不管的。

# 优化器Y1_op = tf.train.AdamOptimizer().minimize(Y1_Loss)Y2_op = tf.train.AdamOptimizer().minimize(Y2_Loss)

在训练上面我一开始也有些疑惑,首先是feed数据上面的,是否还需要同时把两个标签的数据都输入呢?后来发现的却需要这样,那么就意味着另一任务还是会进行正向传播运算的。

# Calculation (Session) Code# ==========================# open the sessionwith tf.Session() as session:    session.run(tf.initialize_all_variables())    for iters in range(10):        if np.random.rand() < 0.5:            _, Y1_loss = session.run([Y1_op, Y1_Loss],                            {                              X: np.random.rand(10,10)*10,                              Y1: np.random.rand(10,20)*10,                              Y2: np.random.rand(10,20)*10                              })            print(Y1_loss)        else:            _, Y2_loss = session.run([Y2_op, Y2_Loss],                            {                              X: np.random.rand(10,10)*10,                              Y1: np.random.rand(10,20)*10,                              Y2: np.random.rand(10,20)*10                              })            print(Y2_loss)

由此看来这种方法效率还是有点低。

联合训练

两个优化器需要分别训练,我们把他俩联合在一起,不就可以同时训练了吗? 

原理很简单,把两个loss相加即可。得到的图是这样的: 
joint
代码:

# 计算LossY1_Loss = tf.nn.l2_loss(Y1-Y1_layer)Y2_Loss = tf.nn.l2_loss(Y2-Y2_layer)Joint_Loss = Y1_Loss + Y2_Loss# 优化器Optimiser = tf.train.AdamOptimizer().minimize(Joint_Loss)Y1_op = tf.train.AdamOptimizer().minimize(Y1_Loss)Y2_op = tf.train.AdamOptimizer().minimize(Y2_Loss)# 联合训练# Calculation (Session) Code# ==========================# open the sessionwith tf.Session() as session:    session.run(tf.initialize_all_variables())    _, Joint_Loss = session.run([Optimiser, Joint_Loss],                    {                      X: np.random.rand(10,10)*10,                      Y1: np.random.rand(10,20)*10,                      Y2: np.random.rand(10,20)*10                      })    print(Joint_Loss)

这是原文的代码,其中定义的Y1_opY2_op并没有使用,应该是多此一举了。

如何选择?

什么时候交替训练好?

Alternate training is a good idea when you have two different datasets for each of the different tasks (for example, translating from English to French and English to German). By designing a network in this way, you can improve the performance of each of your individual tasks without having to find more task-specific training data.

当对每个不同的任务有两个不同的数据集(例如,从英语翻译成法语,英语翻译成德语)时,交替训练是一个好主意。通过以这种方式设计网络,可以提高每个任务的性能,而无需找到更多任务特定的训练数据。

这里的例子很好理解,但是“数据集”指的应该不是输入数据X。我认为应该是指输出的结果Y_1、Y_2关联不大。

什么时候联合训练好?

交替训练容易对某一类产生偏向,当对于相同数据集,产生不同属性的输出时,保持任务的独立性,使用联合训练较好。


这两种方式在实际中也成功实现了,不过目前准确率还不是很高,有待改进。

 

你可能感兴趣的文章
两个相交的单向链表求交点
查看>>
归并排序
查看>>
寻找无序数组中的全部降序对
查看>>
寻找字符串里第一个只出现过一次的字符
查看>>
消息id乱序接收但顺序发送问题
查看>>
数组最大连续乘积
查看>>
三个非比较排序(线性排序)
查看>>
把奇/偶数(或某种特征的数)都放在数组左边问题
查看>>
海里数据topk
查看>>
字符串左右旋
查看>>
二叉树按层遍历并按层打印和蛇形打印
查看>>
二叉树打印节点和为某值的全部路径
查看>>
打印普通二叉树最大搜索子树
查看>>
bitmap用途
查看>>
LRUCache
查看>>
布隆过滤器
查看>>
Hash总览
查看>>
关于redis
查看>>
排序总览
查看>>
关于c++的class(继承、重载、隐藏)
查看>>