正文内容:

  • 背景
  • 架构
  • pytorch复现
  • 细节(发现的问题)
  • tensorflow&keras使用初体验
  • 参考资料

背景

 2016年由微软提出,是一个用于推荐系统的深度学习架构。

 deepcrossing模型的应用场景见于广告推荐,即搜索引擎在帮助用户进行信息检索时,可以返回一些广告提供给用户。而考虑用户对不同的广告的点击概率不相同,所以需要对这些广告进行点击率ctr预测,从而投放对于用户而言点击率最高的几个广告。

架构

 deepcrossing的架构很简单,由残差连接,嵌入层等组成。

 对于广告而言,其会有一些数值型特征如广告的费用,广告的图片,也会有一些类别型特征。对于一般的神经网络而言,其只接受数值型的输入,而对于类别型的特征,需要进行数值化处理,如获取其one-hot编码。但是对于一个类别而言,其取值千变多样,那么one-hot编码会导致特征的维度急剧上升,引发维数灾难,既占据存储空间也不利于训练。所以为了解决类别型数据可能导致的维数灾难问题,引入embedding的概念,这个概念在协同过滤中已经学习过,在这里就不过多阐述。大概意思就是使用一个实向量表示某个类别特征中某一个取值。

 Embedding技术可以有效降维。对类别型数据进行embedding编码之后,将其与数值型数据进行组合,即可得到一个广告的总的特征向量,那么基于这个向量,我们就可以去计算该广告的可能点击率(ctr)。

 且deepcrossing采用了残差连接的方式。残差连接的好处在于可以训练更深层的网络,获取更多的交叉特征,且其捕捉细微的变化的能力的更强,可以训练出更精细的模型。

下面给出deepcrossing的大致架构图:


图片引用自:https://blog.csdn.net/wuzhongqiang/article/details/108948440

 总结而言,deepcrossing模型结构很简单,目前多用于ctr点击率的预测。

pytorch复现

 尝试使用pytorch复现,写了一个晚上的代码。。。然后出现了不少bug,还没有解决完,但是宿舍晚上要断电,台式机没电了,所以只能先上传bug not fixed版本到GitHub,呜呜。仓库地址https://github.com/EternalStarICe/recommendation-system-model

顺便附上一些bug截图,希望工作人员检查的时候可以帮我看看代码哪里出了问题😟(如果有时间帮忙当然是最好啦!)

  1. device为cuda时

     我感觉问题可能出在我deepcrossing定义时,每个类别变量我都使用了一个nn.Embedding来处理,然而所有的nn.Embedding都存储在self.embeddings(这是一个list),而该list不在gpu上,所以不是同一个设备?

  2. cuda使用错误之后尝试使用cpu,将device切换为cpu

     这个bug还没来得及去查阅资料处理和解决😭。

细节(发现的问题)

 也是同小组的队友发现了代码中的问题并在小组中进行讨论。(截图来自队友)

 可以发现,这里红框中标出的变量是data,但是对于上面的数据预处理函数data_preprocess而言,其返回了train_data,那么这里应该是使用train_data而不应该是data,但是经过print输出发现,此时data和train_data的值是相同的,于是怀疑是data_preprocess就已经将data原值进行了修改,所以此处data也可以正常使用。后经查阅资料[2],发现若函数得到的参数是一个修改的值如列表等,在函数内部对变量进行修改时会直接修改原值。

如果想要避免直接对原值进行修改的话,可以在参数传入进函数之后生成一个copy对象,对copy对象进行修改。

tensorflow&keras使用初体验

 之前一直都是使用的pytorch进行深度学习网络的构建,而这次的学习代码提供的是tf版本,所以借此契机正好学习一下tensorflow。
 一开始看到keras这种API形式搭建网络框架的形式,感觉很不习惯。在pytorch中,需要自定义一个类来实现网络框架,需要自定义每一层的,定义forward如何处理数据。而keras通过提供的API(各种层),可以很轻松地完成pytorch需要花费大量功夫实现的功能(当然我菜也是很大一部分原因orz)。

 比如通过绑定字典的key和input的name可以指定数据传入,而pytorch需要对数据格式进行变换才能满足网络上的定义的数据顺序,就没那么方便。(keras真香)

 之前还在因为不会tf看不懂其他人tf代码,就借这次机会学学tf吧😆,冲!

参考资料

  1. https://blog.csdn.net/wuzhongqiang/article/details/108948440
  2. https://www.cnblogs.com/sogeisetsu/articles/11595275.html
  3. https://github.com/datawhalechina/team-learning-rs/tree/master/DeepRecommendationModel
  4. https://www.cnblogs.com/xiaxiaoxu/p/9742452.html

0 条评论

发表评论