博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Tf中的NCE-loss实现学习【转载】
阅读量:6095 次
发布时间:2019-06-20

本文共 2453 字,大约阅读时间需要 8 分钟。

转自:

 1.tf中的nce_loss的API

def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,             num_true=1,             sampled_values=None,             remove_accidental_hits=False,             partition_strategy="mod",             name="nce_loss")

假设nce_loss之前的输入数据是K维的,一共有N个类,那么

  • weight.shape = (N, K)
  • bias.shape = (N)
  • inputs.shape = (batch_size, K)
  • labels.shape = (batch_size, num_true)
  • num_true : 实际的正样本个数
  • num_sampled: 采样出多少个负样本
  • num_classes = N
  • sampled_values: 采样出的负样本,如果是None,就会用不同的sampler去采样。待会儿说sampler是什么。
  • remove_accidental_hits: 如果采样时不小心采样到的负样本刚好是正样本,要不要干掉
  • partition_strategy:对weights进行时并行查表时的策略。TF的embeding_lookup是在CPU里实现的,这里需要考虑多线程查表时的锁的问题。

nce_loss的实现逻辑如下:

  • _compute_sampled_logits: 通过这个函数计算出正样本和采样出的负样本对应的output和label
  • sigmoid_cross_entropy_with_logits: 通过 sigmoid cross entropy来计算output和label的loss,从而进行反向传播。这个函数把最后的问题转化为了num_sampled+num_real个两类分类问题,然后每个分类问题用了交叉熵的损伤函数,也就是logistic regression常用的损失函数。TF里还提供了一个softmax_cross_entropy_with_logits的函数,和这个有所区别。

2.tf中word2vec实现

loss = tf.reduce_mean(      tf.nn.nce_loss(nce_weights, nce_biases, embed, train_labels,                     num_sampled, vocabulary_size))

 它这里并没有传sampled_values,那么它的负样本是怎么得到的呢?继续看nce_loss的实现,可以看到里面处理sampled_values=None的代码如下:

if sampled_values is None:      sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(          true_classes=labels,          num_true=num_true,          num_sampled=num_sampled,          unique=True,          range_max=num_classes)

所以,默认情况下,他会用log_uniform_candidate_sampler去采样。那么log_uniform_candidate_sampler是怎么采样的呢?他的实现在:

  • 他会在[0, range_max)中采样出一个整数k
  • P(k) = (log(k + 2) - log(k + 1)) / log(range_max + 1)

可以看到,k越大,被采样到的概率越小。那么在TF的word2vec里,类别的编号有什么含义吗?看下面的代码:

 

def build_dataset(words):  count = [['UNK', -1]]  count.extend(collections.Counter(words).most_common(vocabulary_size - 1))  dictionary = dict()  for word, _ in count:    dictionary[word] = len(dictionary)  data = list()  unk_count = 0  for word in words:    if word in dictionary:      index = dictionary[word]    else:      index = 0  # dictionary['UNK']      unk_count += 1    data.append(index)  count[0][1] = unk_count  reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys()))  return data, count, dictionary, reverse_dictionary

可以看到,TF的word2vec实现里,词频越大,词的类别编号也就越小。因此,在TF的word2vec里,负采样的过程其实就是优先采词频高的词作为负样本。

在提出负采样的, 包括word2vec的。是按照热门度的0.75次方采样的,这个和TF的实现有所区别。但大概的意思差不多,就是越热门,越有可能成为负样本。

转载于:https://www.cnblogs.com/BlueBlueSea/p/10615766.html

你可能感兴趣的文章
wireshark抓包图解 TCP三次握手/四次挥手详解
查看>>
晚上练习一点
查看>>
西安邀请赛-D(带权并查集+背包)
查看>>
jqueryEasyui重新渲染
查看>>
ASP.NET MVC 第二章 路由和修改路由
查看>>
redis
查看>>
生产者与消费者
查看>>
simulate web browser
查看>>
关于volatile(转)
查看>>
maven pom文件标签含义
查看>>
github访问配置
查看>>
一些个人偏好的书籍
查看>>
数据库导出sql
查看>>
C语言的printf输出格式控制
查看>>
修改const变量
查看>>
CCF认证历年试题解【网上跟帖,请不要使用称呼】
查看>>
HDU1576 A/B【扩展欧几里得算法】
查看>>
Ext.net常见问题收集
查看>>
PHP入门part4
查看>>
查看AWR和获取缓存库中的执行计划
查看>>