李沐69_BERT训练集——自学笔记

NLP里的迁移学习

1.使用预训练好的模型来抽取词、句子的特征,例如word2vec或语言模型

2.不更新预训练好的模型

3.需要构建新的网络来抓取新任务需要的信息:word2vec忽略了时序信息,语言模型只看了一个方向

BERT的动机

1.基于微调的NLP模型

2.预训练的模型抽取了足够多的信息

3.新的任务只需要增加一个简单的输出层

对输入的修改

1.每个样本是一个句子对

2.加入额外的片段嵌入

3.位置编码可学习

预训练任务1:带掩码的语言模型

1.Transformer的编码器是双向,标准语言模型要求单向

2.带掩码的语言模型每次随机(15%)将一些词元换成mask

3.因为微调任务中不出现mask:80%概率下将选中的词元变成mask、10%概率换成一个随机词元、10%保持原有词元

预训练任务2:下一句子预测

1.预测下一个句子对中两个句子是不是相邻

2.训练样本中:50%概率选择相邻句子对、50%概率选择随机句子对

!pip install d2l==0.17.6  ### 很重要,不要下载错了,对于colab
import torch
from torch import nn
from d2l import torch as d2l

get_tokens_and_segments将一个句子或两个句子作为输入,然后返回BERT输入序列的标记及其相应的片段索引。

def get_tokens_and_segments(tokens_a, tokens_b=None):
    """获取输入序列的词元及其片段索引"""
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    # 0和1分别标记片段A和B
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments

与TransformerEncoder不同,BERTEncoder使用片段嵌入和可学习的位置嵌入。

class BERTEncoder(nn.Module):
    """BERT编码器"""
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 **kwargs):
        super(BERTEncoder, self).__init__(**kwargs)
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(f"{i}", d2l.EncoderBlock(
                key_size, query_size, value_size, num_hiddens, norm_shape,
                ffn_num_input, ffn_num_hiddens, num_heads, dropout, True))
        # 在BERT中,位置嵌入是可学习的,因此我们创建一个足够长的位置嵌入参数
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len,
                                                      num_hiddens))

    def forward(self, tokens, segments, valid_lens):
        # 在以下代码段中,X的形状保持不变:(批量大小,最大序列长度,num_hiddens)
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        X = X + self.pos_embedding.data[:, :X.shape[1], :]
        for blk in self.blks:
            X = blk(X, valid_lens)
        return X

假设词表大小为10000,为了演示BERTEncoder的前向推断,让我们创建一个实例并初始化它的参数。

vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4
norm_shape, ffn_num_input, num_layers, dropout = [768], 768, 2, 0.2
encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input,
                      ffn_num_hiddens, num_heads, num_layers, dropout)

我们将tokens定义为长度为8的2个输入序列,其中每个词元是词表的索引。使用输入tokens的BERTEncoder的前向推断返回编码结果,其中每个词元由向量表示,其长度由超参数num_hiddens定义。此超参数通常称为Transformer编码器的隐藏大小(隐藏单元数)。

tokens = torch.randint(0, vocab_size, (2, 8))
segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
encoded_X = encoder(tokens, segments, None)
encoded_X.shape
torch.Size([2, 8, 768])

下面的MaskLM类来预测BERT预训练的掩蔽语言模型任务中的掩蔽标记。预测使用单隐藏层的多层感知机(self.mlp)。在前向推断中,它需要两个输入:BERTEncoder的编码结果和用于预测的词元位置。输出是这些位置的预测结果。

class MaskLM(nn.Module):
    """BERT的掩蔽语言模型任务"""
    def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs):
        super(MaskLM, self).__init__(**kwargs)
        self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),
                                 nn.ReLU(),
                                 nn.LayerNorm(num_hiddens),
                                 nn.Linear(num_hiddens, vocab_size))

    def forward(self, X, pred_positions):
        num_pred_positions = pred_positions.shape[1]
        pred_positions = pred_positions.reshape(-1)
        batch_size = X.shape[0]
        batch_idx = torch.arange(0, batch_size)
        # 假设batch_size=2,num_pred_positions=3
        # 那么batch_idx是np.array([0,0,0,1,1,1])
        batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)
        masked_X = X[batch_idx, pred_positions]
        masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))
        mlm_Y_hat = self.mlp(masked_X)
        return mlm_Y_hat

为了演示MaskLM的前向推断,我们创建了其实例mlm并对其进行了初始化。回想一下,来自BERTEncoder的正向推断encoded_X表示2个BERT输入序列。我们将mlm_positions定义为在encoded_X的任一输入序列中预测的3个指示。mlm的前向推断返回encoded_X的所有掩蔽位置mlm_positions处的预测结果mlm_Y_hat。对于每个预测,结果的大小等于词表的大小。

mlm = MaskLM(vocab_size, num_hiddens)
mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])
mlm_Y_hat = mlm(encoded_X, mlm_positions)
mlm_Y_hat.shape
torch.Size([2, 3, 10000])

通过掩码下的预测词元mlm_Y的真实标签mlm_Y_hat,我们可以计算在BERT预训练中的遮蔽语言模型任务的交叉熵损失。

mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])
loss = nn.CrossEntropyLoss(reduction='none')
mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))
mlm_l.shape
torch.Size([6])

下一句预测(Next Sentence Prediction)

class NextSentencePred(nn.Module):
    """BERT的下一句预测任务"""
    def __init__(self, num_inputs, **kwargs):
        super(NextSentencePred, self).__init__(**kwargs)
        self.output = nn.Linear(num_inputs, 2)

    def forward(self, X):
        # X的形状:(batchsize,num_hiddens)
        return self.output(X)

NextSentencePred实例的前向推断返回每个BERT输入序列的二分类预测。

encoded_X = torch.flatten(encoded_X, start_dim=1)
# NSP的输入形状:(batchsize,num_hiddens)
nsp = NextSentencePred(encoded_X.shape[-1])
nsp_Y_hat = nsp(encoded_X)
nsp_Y_hat.shape
torch.Size([2, 2])

还可以计算两个二元分类的交叉熵损失。

nsp_y = torch.tensor([0, 1])
nsp_l = loss(nsp_Y_hat, nsp_y)
nsp_l.shape
torch.Size([2])

整合编码

class BERTModel(nn.Module):
    """BERT模型"""
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 hid_in_features=768, mlm_in_features=768,
                 nsp_in_features=768):
        super(BERTModel, self).__init__()
        self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape,
                    ffn_num_input, ffn_num_hiddens, num_heads, num_layers,
                    dropout, max_len=max_len, key_size=key_size,
                    query_size=query_size, value_size=value_size)
        self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens),
                                    nn.Tanh())
        self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)
        self.nsp = NextSentencePred(nsp_in_features)

    def forward(self, tokens, segments, valid_lens=None,
                pred_positions=None):
        encoded_X = self.encoder(tokens, segments, valid_lens)
        if pred_positions is not None:
            mlm_Y_hat = self.mlm(encoded_X, pred_positions)
        else:
            mlm_Y_hat = None
        # 用于下一句预测的多层感知机分类器的隐藏层,0是“<cls>”标记的索引
        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
        return encoded_X, mlm_Y_hat, nsp_Y_hat

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/576744.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

又重新搭了个个人博客

哈喽大家好&#xff0c;我是咸鱼。 前段时间看到一个学弟写了篇用 Hexo 搭建博客的教程&#xff0c;心中沉寂已久的激情重新被点燃起来。&#xff08;以前搞过一个个人网站&#xff0c;但是因为种种原因最后不了了之&#xff09; 于是花了一天时间参考教程搭了个博客网站&…

广工电工与电子技术实验报告-8路彩灯循环控制电路

实验代码 module LED_water (clk,led); input clk; output [7:0] led; reg [7:0] led; integer p; reg clk_1Hz; reg [7:0] current_state, next_state; always (posedge clk) begin if(p25000000-1)begin …

JVM(Jvm如何管理空间?对象如何存储、管理?)

Jvm如何管理空间&#xff08;Java运行时数据区域与分配空间的方式&#xff09; ⭐运行时数据区域 程序计数器 程序计数器&#xff08;PC&#xff09;&#xff0c;是一块较小的内存空。它可以看作是当前线程所执行的字节码的行号指示器。Java虚拟机的多线程是通过时间片轮转调…

Flutter 从 Assets 中读取 JSON 文件:指南 [2024]

在本教程中&#xff0c;我们将探讨如何从 Flutter 项目中的 asset 中读取 JSON 文件。您将找到详细的解释、实际示例和最佳实践&#xff0c;使您的 JSON 文件处理顺利高效。那么&#xff0c;让我们深入了解 Flutter 和 JSON 的世界吧&#xff01; 从 asset 中读取 JSON 文件 …

RTK负载(4K可见光+高分热成像+超广角+激光测距)四光AI智能识别跟踪吊舱技术详解

无人机光电吊舱的RTK负载&#xff08;4K可见光高分热成像超广角激光测距&#xff09;AI智能识别跟踪吊舱技术是一种高度集成和先进的无人机观测系统。系统结合了无人机的飞行能力和光电吊舱的多功能传感器&#xff0c;通过集成RTK&#xff08;实时动态差分定位&#xff09;技术…

WebGL开发框架比较

WebGL开发框架提供了一套丰富的工具和API&#xff0c;使得在Web浏览器中创建和操作3D图形变得更加容易。以下是一些流行的WebGL开发框架及其各自的优缺点。北京木奇移动技术有限公司&#xff0c;专业的软件外包开发公司&#xff0c;欢迎交流合作。 1.Three.js 优点&#xff1a…

WoodMart主题下载:为您的电商网站带来自然而优雅的购物体验

在电子商务的激烈竞争中&#xff0c;一个设计精良、用户友好的在线商店是吸引和保留客户的关键。WoodMart主题&#xff0c;作为一款专为Shopify平台设计的高级主题&#xff0c;以其自然美学和强大的功能&#xff0c;帮助您的商店在众多竞争对手中脱颖而出。 [WoodMart主题的核…

运行游戏提示dll文件丢失,分享多种有效的解决方法

在我们日常频繁地利用电脑进行娱乐活动&#xff0c;特别是畅玩各类精彩纷呈的电子游戏时&#xff0c;常常会遭遇一个令人困扰的问题。当我们满怀期待地双击图标启动心仪的游戏程序&#xff0c;准备全身心投入虚拟世界时&#xff0c;屏幕上却赫然弹出一条醒目的错误提示信息&…

【书生浦语第二期实战营学习作业笔记(二)】

书生浦语第二期实战营学习作业&笔记(二) 操作文档&#xff1a;https://github.com/InternLM/Tutorial/blob/camp2/helloworld/hello_world.md 基础作业 &#xff1a; 使用 InternLM2-Chat-1.8B 模型生成 300 字的小故事&#xff1a; 八戒部署&#xff08;笔记&#xff0…

基于灰狼算法的综合能源系统多时间尺度优化调度(附matlab程序)

0.代码链接 基于灰狼算法的综合能源系统多时间尺度优化调度&#xff08;MATLAB程序&#xff09;资源-CSDN文库 1.简述 对于冷、热、电联供综合能源系统优化问题&#xff0c;为了提高可再生能源利用率&#xff0c;故以弃风、弃光量最小和综合能源系统运行经济性为优化目标&…

mac配置maven

在 macOS 上配置 Maven 也相对简单。以下是一种常用的方法&#xff1a; 1. 安装maven **下载 Maven&#xff1a;**首先&#xff0c;你需要从 Maven 官网&#xff08;https://maven.apache.org/download.cgi&#xff09;下载最新版本的 Maven。你可以选择二进制压缩包&#xf…

redis常用数据结构

redis常用数据结构 Redis 底层在实现下面数据结构的时候&#xff0c;会进行特定的优化&#xff0c;来达到节省时间/空间的效果。 内部结构 String raw&#xff08;最基本的字符串&#xff09;&#xff0c;int&#xff08;实现计数功能&#xff0c;当value为整数的时候会用整…

JPEG图像常用加密算法简介

JPEG图像加密算法 目前&#xff0c;JPEG图像加密算法可以分成异或加密、置乱加密和置乱与异或组合加密。下面对这三种加密方式进行阐述。 (1) 异或加密 文献[1]提出了一种基于异或加密的JPEG图像的RDH-EI方案。该算法通过对AC系数的ACA和图像的量化表进行流密码异或&#xf…

Feign功能详解、使用步骤、代码案例

简介&#xff1a;Feign是Netflix开发的声明式&#xff0c;模板化的HTTP客户端&#xff0c;简化了HTTP的远程服务的开发。Feign是在RestTemplate和Ribbon的基础上进一步封装&#xff0c;使用RestTemplate实现Http调用&#xff0c;使用Ribbon实现负载均衡。我们可以看成 Feign R…

经典的免费wordpress模板

这款经典的免费WordPress模板以鲜艳的红色为主调&#xff0c;充满了活力与热情。设计简洁而现代&#xff0c;适合各种类型的项目网站。模板采用响应式设计&#xff0c;确保在不同设备和屏幕尺寸上都能呈现出完美的视觉效果。 红色象征着激情、活力和自信&#xff0c;这款模板…

ubuntu子系统密码忘记了,怎么办?

&#x1f3c6;本文收录于「Bug调优」专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收藏&&…

18种WEB常见漏洞:揭秘网络安全的薄弱点

输入验证漏洞: 认证和会话管理漏洞: 安全配置错误: 其他漏洞: 防范措施: Web 应用程序是现代互联网的核心&#xff0c;但它们也容易受到各种安全漏洞的影响。了解常见的 Web 漏洞类型&#xff0c;对于开发人员、安全测试人员和普通用户都至关重要。以下将介绍 18 种常见的 …

HttpClient工具类编写

HttpClient 介绍 HttpClient是Apache Jakarta Common下的一个子项目&#xff0c;它提供了一个高效的、最新的、功能丰富的支持HTTP协议的客户端编程工具包。它支持HTTP协议最新的版本和建议&#xff0c;并实现了Http1.0和Http1.1协议。 HttpClient具有可扩展的面向对象的结构…

电脑壁纸怎么设置?简单3步,让你的电脑桌面变得更合心意

电脑壁纸是我们每天在电脑上工作和娱乐时不可或缺的一部分。一张精美的电脑壁纸&#xff0c;既能提升我们的工作效率&#xff0c;也能为我们带来愉悦的心情。无论是静谧的自然风光、抽象的艺术设计&#xff0c;还是心动的明星照片&#xff0c;都可以在电脑壁纸的世界里找到自己…

HF区块链链码基础

链码生命周期 一 . 链码准备 准备文件 . 在测试目录下创建chaincode,拷贝测试链码进 chaincode目录,拷贝 set-env.sh 脚本进 scripts 目录 二. 打包链码 打包测试链码 export FABRIC_CFG_PATH${PWD}/config peer lifecycle chaincode package ./chaincode/chaincode_basic.…