1美元训练BERT,教你如何薅谷歌TPU羊毛 | 附Colab代码
晓查 发自 凹非寺
量子位 出品 | 公众号 QbitAI
BERT是谷歌去年推出的NLP模型,一经推出就在各项测试中碾压竞争对手,而且BERT是开源的。只可惜训练BERT的价格实在太高,让人望而却步。
之前需要用64个TPU训练4天才能完成,后来谷歌用并行计算优化了到只需一个多小时,但是需要的TPU数量陡增,达到了惊人的1024个。
那么总共要多少钱呢?谷歌云TPU的使用价格是每个每小时6.5美元,训练完成训练完整个模型需要近4万美元,简直就是天价。
现在,有个羊毛告诉你,在Medium上有人找到了薅谷歌羊毛的办法,只需1美元就能训练BERT,模型还能留存在你的谷歌云盘中,留作以后使用。
准备工作
为了薅谷歌的羊毛,您需要一个Google云存储(Google Cloud Storage)空间。按照Google 云TPU快速入门指南,创建Google云平台(Google Cloud Platform)帐户和Google云存储账户。新的谷歌云平台用户可获得300美元的免费赠送金额。
在TPUv2上预训练BERT-Base模型大约需要54小时。Google Colab并非设计用于执行长时间运行的作业,它会每8小时左右中断一次训练过程。对于不间断的训练,请考虑使用付费的不间断使用TPUv2的方法。
也就是说,使用Colab TPU,你可以在以1美元的价格在Google云盘上存储模型和数据,以几乎可忽略成本从头开始预训练BERT模型。
以下是整个过程的代码下面的代码,可以在Colab Jupyter环境中运行。
设置训练环境
首先,安装训练模型所需的包。Jupyter允许使用’!’直接从笔记本执行bash命令:
!pip install sentencepiece
!git clone https://github.com/google-research/bert
导入包并在Google云中授权:
import os
import sys
import json
import nltk
import random
import logging
import tensorflow as tf
import sentencepiece as spm
from glob import glob
from google.colab import auth, drive
from tensorflow.keras.utils import Progbar
sys.path.append("bert")
from bert import modeling, optimization, tokenization
from bert.run_pretraining import input_fn_builder, model_fn_builder
auth.authenticate_user()
# configure logging
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)
# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s : %(message)s')
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
log.handlers = [sh]
if 'COLAB_TPU_ADDR' in os.environ:
log.info("Using TPU runtime")
USE_TPU = True
TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']
with tf.Session(TPU_ADDRESS) as session:
log.info('TPU address is ' + TPU_ADDRESS)
# Upload credentials to TPU.
with open('/content/adc.json', 'r') as f:
auth_info = json.load(f)
tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
else:
log.warning('Not connected to TPU runtime')
USE_TPU = False
下载原始文本数据
接下来从网络上获取文本数据语料库。在本次实验中,我们使用OpenSubtitles数据集,该数据集包括65种语言。
与更常用的文本数据集(如维基百科)不同,它不需要任何复杂的预处理,提供预格式化,一行一个句子。
AVAILABLE = {'af','ar','bg','bn','br','bs','ca','cs',
'da','de','el','en','eo','es','et','eu',
'fa','fi','fr','gl','he','hi','hr','hu',
'hy','id','is','it','ja','ka','kk','ko',
'lt','lv','mk','ml','ms','nl','no','pl',
'pt','pt_br','ro','ru','si','sk','sl','sq',
'sr','sv','ta','te','th','tl','tr','uk',
'ur','vi','ze_en','ze_zh','zh','zh_cn',
'zh_en','zh_tw','zh_zh'}
LANG_CODE = "en" #@param {type:"string"}
assert LANG_CODE in AVAILABLE, "Invalid language code selected"
!wget http://opus.nlpl.eu/download.php?f=OpenSubtitles/v2016/mono/OpenSubtitles.raw.'$LANG_CODE'.gz -O dataset.txt.gz
!gzip -d dataset.txt.gz
!tail dataset.txt
你可以通过设置代码随意选择你需要的语言。出于演示目的,代码只默认使用整个语料库的一小部分。在实际训练模型时,请务必取消选中DEMO_MODE复选框,使用大100倍的数据集。
当然,100M数据足以训练出相当不错的BERT基础模型。
DEMO_MODE = True #@param {type:"boolean"}
if DEMO_MODE:
CORPUS_SIZE = 1000000
else:
CORPUS_SIZE = 100000000 #@param {type: "integer"}
!(head -n $CORPUS_SIZE dataset.txt) > subdataset.txt
!mv subdataset.txt dataset.txt
预处理文本数据
我们下载的原始文本数据包含标点符号,大写字母和非UTF符号,我们将在继续下一步之前将其删除。在推理期间,我们将对新数据应用相同的过程。
如果你需要不同的预处理方式(例如在推理期间预期会出现大写字母或标点符号),请修改以下代码以满足你的需求。
regex_tokenizer = nltk.RegexpTokenizer("w+")
def normalize_text(text):
# lowercase text
text = str(text).lower()
# remove non-UTF
text = text.encode("utf-8", "ignore").decode()
# remove punktuation symbols
text = " ".join(regex_tokenizer.tokenize(text))
return text
def count_lines(filename):
count = 0
with open(filename) as fi:
for line in fi:
count += 1
return count
现在让我们预处理整个数据集:
RAW_DATA_FPATH = "dataset.txt" #@param {type: "string"}
PRC_DATA_FPATH = "proc_dataset.txt" #@param {type: "string"}
# apply normalization to the dataset
# this will take a minute or two
total_lines = count_lines(RAW_DATA_FPATH)
bar = Progbar(total_lines)
with open(RAW_DATA_FPATH,encoding="utf-8") as fi:
with open(PRC_DATA_FPATH, "w",encoding="utf-8") as fo:
for l in fi:
fo.write(normalize_text(l)+"n")
bar.add(1)
构建词汇表
下一步,我们将训练模型学习一个新的词汇表,用于表示我们的数据集。
BERT文件使用WordPiece分词器,在开源中不可用。我们将在unigram模式下使用SentencePiece分词器。虽然它与BERT不直接兼容,但是通过一个小的处理方法,可以使它工作。
SentencePiece需要相当多的运行内存,因此在Colab中的运行完整数据集会导致内核崩溃。
为避免这种情况,我们将随机对数据集的一小部分进行子采样,构建词汇表。另一个选择是使用更大内存的机器来执行此步骤。
此外,SentencePiece默认情况下将BOS和EOS控制符号添加到词汇表中。我们通过将其索引设置为-1来禁用它们。
VOC_SIZE的典型值介于32000和128000之间。如果想要更新词汇表,并在预训练阶段结束后对模型进行微调,我们会保留NUM_PLACEHOLDERS个token。
MODEL_PREFIX = "tokenizer" #@param {type: "string"}
VOC_SIZE = 32000 #@param {type:"integer"}
SUBSAMPLE_SIZE = 12800000 #@param {type:"integer"}
NUM_PLACEHOLDERS = 256 #@param {type:"integer"}
SPM_COMMAND = ('--input={} --model_prefix={} '
'--vocab_size={} --input_sentence_size={} '
'--shuffle_input_sentence=true '
'--bos_id=-1 --eos_id=-1').format(
PRC_DATA_FPATH, MODEL_PREFIX,
VOC_SIZE - NUM_PLACEHOLDERS, SUBSAMPLE_SIZE)
spm.SentencePieceTrainer.Train(SPM_COMMAND)
现在,让我们看看如何让SentencePiece在BERT模型上工作。
下面是使用来自官方的预训练英语BERT基础模型的WordPiece词汇表标记的语句。
>>> wordpiece.tokenize("Colorless geothermal substations are generating furiously")
['color',
'##less',
'geo',
'##thermal',
'sub',
'##station',
'##s',
'are',
'generating',
'furiously']
WordPiece标记器在“##”的单词中间预置了出现的子字。在单词开头出现的子词不变。如果子词出现在单词的开头和中间,则两个版本(带和不带’##’)都会添加到词汇表中。
SentencePiece创建了两个文件:tokenizer.model和tokenizer.vocab。让我们来看看它学到的词汇:
def read_sentencepiece_vocab(filepath):
voc = []
with open(filepath, encoding='utf-8') as fi:
for line in fi:
voc.append(line.split("t")[0])
# skip the first token
voc = voc[1:]
return voc
snt_vocab = read_sentencepiece_vocab("{}.vocab".format(MODEL_PREFIX))
print("Learnt vocab size: {}".format(len(snt_vocab)))
print("Sample tokens: {}".format(random.sample(snt_vocab, 10)))
运行结果:
Learnt vocab size: 31743
Sample tokens: ['▁cafe', '▁slippery', 'xious', '▁resonate', '▁terrier', '▁feat', '▁frequencies', 'ainty', '▁punning', 'modern']
SentencePiece与WordPiece的运行结果完全相反。从文档中可以看出:SentencePiece首先使用元符号“_”将空格转义为空格,如下所示:
Hello_World。
然后文本被分段为小块:
[Hello] [_Wor] [ld] [.]
在空格之后出现的子词(也是大多数词开头的子词)前面加上“_”,而其他子词不变。这排除了仅出现在句子开头而不是其他地方的子词。然而,这些案件应该非常罕见。
因此,为了获得类似于WordPiece的词汇表,我们需要执行一个简单的转换,从包含它的标记中删除“_”,并将“##”添加到不包含它的标记中。
我们还添加了一些BERT架构所需的特殊控制符号。按照惯例,我们把它们放在词汇的开头。
另外,我们在词汇表中添加了一些占位符token。
如果你希望使用新的用于特定任务的token来更新预先训练的模型,那么这些方法是很有用的。
在这种情况下,占位符token被替换为新的token,重新生成预训练数据,并且对新数据进行微调。
def parse_sentencepiece_token(token):
if token.startswith("▁"):
return token[1:]
else:
return "##" + token
bert_vocab = list(map(parse_sentencepiece_token, snt_vocab))
ctrl_symbols = ["[PAD]","[UNK]","[CLS]","[SEP]","[MASK]"]
bert_vocab = ctrl_symbols + bert_vocab
bert_vocab += ["[UNUSED_{}]".format(i) for i in range(VOC_SIZE - len(bert_vocab))]
print(len(bert_vocab))
最后,我们将获得的词汇表写入文件。
VOC_FNAME = "vocab.txt" #@param {type:"string"}
with open(VOC_FNAME, "w") as fo:
for token in bert_vocab:
fo.write(token+"n")
现在,让我们看看新词汇在实践中是如何运作的:
>>> testcase = "Colorless geothermal substations are generating furiously"
>>> bert_tokenizer = tokenization.FullTokenizer(VOC_FNAME)
>>> bert_tokenizer.tokenize(testcase)
['color',
'##less',
'geo',
'##ther',
'##mal',
'sub',
'##station',
'##s',
'are',
'generat',
'##ing',
'furious',
'##ly']
创建分片预训练数据(生成预训练数据)
通过手头的词汇表,我们可以为BERT模型生成预训练数据。
由于我们的数据集可能非常大,我们将其拆分为碎片:
mkdir ./shards
split -a 4 -l 256000 -d $PRC_DATA_FPATH ./shards/shard_
现在,对于每个部分,我们需要从BERT仓库调用create_pretraining_data.py脚本,需要使用xargs命令。
在开始生成之前,我们需要设置一些参数传递给脚本。你可以从自述文件中找到有关它们含义的更多信息。
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param
MAX_PREDICTIONS = 20 #@param {type:"integer"}
DO_LOWER_CASE = True #@param {type:"boolean"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
# controls how many parallel processes xargs can create
PROCESSES = 2 #@param {type:"integer"}
运行此操作可能需要相当长的时间,具体取决于数据集的大小。
XARGS_CMD = ("ls ./shards/ | "
"xargs -n 1 -P {} -I{} "
"python3 bert/create_pretraining_data.py "
"--input_file=./shards/{} "
"--output_file={}/{}.tfrecord "
"--vocab_file={} "
"--do_lower_case={} "
"--max_predictions_per_seq={} "
"--max_seq_length={} "
"--masked_lm_prob={} "
"--random_seed=34 "
"--dupe_factor=5")
XARGS_CMD = XARGS_CMD.format(PROCESSES, '{}', '{}', PRETRAINING_DIR, '{}',
VOC_FNAME, DO_LOWER_CASE,
MAX_PREDICTIONS, MAX_SEQ_LENGTH, MASKED_LM_PROB)
tf.gfile.MkDir(PRETRAINING_DIR)
!$XARGS_CMD
为数据和模型设置GCS存储,将数据和模型存储到云端
为了保留来之不易的训练模型,我们会将其保留在Google云存储中。
在Google云存储中创建两个目录,一个用于数据,一个用于模型。在模型目录中,我们将放置模型词汇表和配置文件。
在继续操作之前,请配置BUCKET_NAME变量,否则将无法训练模型。
BUCKET_NAME = "bert_resourses" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
tf.gfile.MkDir(MODEL_DIR)
if not BUCKET_NAME:
log.warning("WARNING: BUCKET_NAME is not set. "
"You will not be able to train the model.")
下面是BERT-base的超参数配置示例:
# use this for BERT-base
bert_base_config = {
"attention_probs_dropout_prob": 0.1,
"directionality": "bidi",
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pooler_fc_size": 768,
"pooler_num_attention_heads": 12,
"pooler_num_fc_layers": 3,
"pooler_size_per_head": 128,
"pooler_type": "first_token_transform",
"type_vocab_size": 2,
"vocab_size": VOC_SIZE
}
with open("{}/bert_config.json".format(MODEL_DIR), "w") as fo:
json.dump(bert_base_config, fo, indent=2)
with open("{}/{}".format(MODEL_DIR, VOC_FNAME), "w") as fo:
for token in bert_vocab:
fo.write(token+"n")
现在,我们已准备好将模型和数据存储到谷歌云当中:
if BUCKET_NAME:
!gsutil -m cp -r $MODEL_DIR $PRETRAINING_DIR gs://$BUCKET_NAME
在云TPU上训练模型
注意,之前步骤中的某些参数在此处不用改变。请确保在整个实验中设置的参数完全相同。
BUCKET_NAME = "bert_resourses" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
VOC_FNAME = "vocab.txt" #@param {type:"string"}
# Input data pipeline config
TRAIN_BATCH_SIZE = 128 #@param {type:"integer"}
MAX_PREDICTIONS = 20 #@param {type:"integer"}
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param
# Training procedure config
EVAL_BATCH_SIZE = 64
LEARNING_RATE = 2e-5
TRAIN_STEPS = 1000000 #@param {type:"integer"}
SAVE_CHECKPOINTS_STEPS = 2500 #@param {type:"integer"}
NUM_TPU_CORES = 8
if BUCKET_NAME:
BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
else:
BUCKET_PATH = "."
BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR)
DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, PRETRAINING_DIR)
VOCAB_FILE = os.path.join(BERT_GCS_DIR, VOC_FNAME)
CONFIG_FILE = os.path.join(BERT_GCS_DIR, "bert_config.json")
INIT_CHECKPOINT = tf.train.latest_checkpoint(BERT_GCS_DIR)
bert_config = modeling.BertConfig.from_json_file(CONFIG_FILE)
input_files = tf.gfile.Glob(os.path.join(DATA_GCS_DIR,'*tfrecord'))
log.info("Using checkpoint: {}".format(INIT_CHECKPOINT))
log.info("Using {} data shards".format(len(input_files)))
准备训练运行配置,建立评估器和输入函数,启动BERT!
model_fn = model_fn_builder(
bert_config=bert_config,
init_checkpoint=INIT_CHECKPOINT,
learning_rate=LEARNING_RATE,
num_train_steps=TRAIN_STEPS,
num_warmup_steps=10,
use_tpu=USE_TPU,
use_one_hot_embeddings=True)
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
run_config = tf.contrib.tpu.RunConfig(
cluster=tpu_cluster_resolver,
model_dir=BERT_GCS_DIR,
save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
tpu_config=tf.contrib.tpu.TPUConfig(
iterations_per_loop=SAVE_CHECKPOINTS_STEPS,
num_shards=NUM_TPU_CORES,
per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))
estimator = tf.contrib.tpu.TPUEstimator(
use_tpu=USE_TPU,
model_fn=model_fn,
config=run_config,
train_batch_size=TRAIN_BATCH_SIZE,
eval_batch_size=EVAL_BATCH_SIZE)
train_input_fn = input_fn_builder(
input_files=input_files,
max_seq_length=MAX_SEQ_LENGTH,
max_predictions_per_seq=MAX_PREDICTIONS,
is_training=True)
执行!
estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)
最后,使用默认参数训练模型需要100万步,约54小时的运行时间。如果内核由于某种原因重新启动,可以从断点处继续训练。
以上就是是在云TPU上从头开始预训练BERT的指南。
下一步
好的,我们已经训练好了模型,接下来可以做什么?
1、使用预训练的模型作为通用的自然语言理解模块;
2、针对某些特定的分类任务微调模型;
3、使用BERT作为构建块,去创建另一个深度学习模型。
传送门
原文地址:
https://towardsdatascience.com/pre-training-bert-from-scratch-with-cloud-tpu-6e2f71028379
Colab代码:
https://colab.research.google.com/drive/1nVn6AFpQSzXBt8_ywfx6XR8ZfQXlKGAz
作者系网易新闻·网易号“各有态度”签约作者
— 完 —
加入社群 | 与优秀的人交流
小程序 | 全类别AI学习教程
量子位 QbitAI · 头条号签约作者
վ'ᴗ' ի 追踪AI技术和产品新动态
喜欢就点「在看」吧 !
关注公众号:拾黑(shiheibook)了解更多
[广告]赞助链接:
四季很好,只要有你,文娱排行榜:https://www.yaopaiming.com/
让资讯触达的更精准有趣:https://www.0xu.cn/