This example uses the Chinese news dataset from here to fine-tune the pretrained BERT model for text classification, then saves the fine-tuned model and tests it through a REST API deployed with Flask. The basic idea comes from the blog post BERT Fine-Tuning Tutorial with PyTorch. Both PyTorch and TensorFlow are used in this work, along with a library named pytorch-pretrained-bert, which helps apply pretrained models such as BERT, GPT, and GPT-2 to downstream tasks.
BERT is a popular pretrained model from Google. Here are some great posts worth reading:
1. Download the BERT-Base, Chinese model and unzip the file
2. Download the dataset and unzip it
wget https://github.com/fate233/toutiao-text-classfication-dataset/raw/master/toutiao_cat_data.txt.zip
3. Fine-tune the model on the dataset
python bert_for_classification.py --output_dir your/output/dir --data_dir toutiao/dataset/dir --data_name toutiao_cat_data.txt --is_add_key_words True
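To make the training input concrete, here is a minimal sketch of how one line of toutiao_cat_data.txt could be turned into a text and a label. It assumes the dataset's `_!_` field separator and the field order id, category code, category name, title, keywords. The names `NewsRecord`, `parse_line`, and `build_text` are hypothetical, and appending keywords to the title is only a guess at what `--is_add_key_words True` does; the actual logic lives in bert_for_classification.py.

```python
from typing import List, NamedTuple

# Assumed field separator of toutiao_cat_data.txt.
SEPARATOR = "_!_"


class NewsRecord(NamedTuple):
    news_id: str
    category_code: str
    category_name: str
    title: str
    keywords: List[str]


def parse_line(line: str) -> NewsRecord:
    """Split one dataset line into its fields (hypothetical helper)."""
    parts = line.rstrip("\n").split(SEPARATOR)
    if len(parts) < 4:
        raise ValueError(f"expected at least 4 fields, got {len(parts)}")
    news_id, code, name, title = parts[:4]
    # The keyword field may be missing or empty.
    keywords = [k for k in parts[4].split(",") if k] if len(parts) > 4 else []
    return NewsRecord(news_id, code, name, title, keywords)


def build_text(record: NewsRecord, add_key_words: bool) -> str:
    """Return the classifier input text.

    Assumption: with add_key_words enabled, keywords are appended to the title.
    """
    if add_key_words and record.keywords:
        return record.title + " " + " ".join(record.keywords)
    return record.title


# Made-up sample line, not taken from the real dataset.
sample = "123_!_104_!_news_finance_!_sample title_!_kw1,kw2"
record = parse_line(sample)
print(record.category_name, "|", build_text(record, add_key_words=True))
```

Keeping the parsing separate from the text-building step makes it easy to compare runs with and without keywords.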
4. Set the output location from the step above in the API file, then run the command below to start the Flask service
Line 9 of classification-api.py: model = torch.load('output')
python classification-api.py
curl -X POST http://xx.xx.xx.xx:8000/predict -H 'Content-Type: application/json' -d '{ "text":"珍惜当下 局部新一轮升浪悄然开启" ,"label":"财经"}' |jq
{"Predict Label":"财经 财经","True Label":"财经"}
curl -X POST http://xx.xx.xx.xx:8000/predict -H 'Content-Type: application/json' -d '{ "text":"美国要在亚太建导弹基地?普京:给你脸了是不是!" ,"label":"军事"}' |jq
{"Predict Label":"国际 国际","True Label":"军事"}
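The same requests can be sent from Python instead of curl. The following is a minimal sketch using only the standard library; it assumes the service accepts the JSON body shown in the curl commands above at `/predict` on port 8000. The helper names `build_predict_request` and `predict` are hypothetical and not part of this repository.

```python
import json
import urllib.request


def build_predict_request(base_url: str, text: str, label: str) -> urllib.request.Request:
    """Build the POST request that mirrors the curl commands above."""
    payload = json.dumps({"text": text, "label": label}, ensure_ascii=False)
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/predict",
        data=payload.encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict(base_url: str, text: str, label: str, timeout: float = 10.0) -> dict:
    """Send the request and decode the JSON response of the service."""
    request = build_predict_request(base_url, text, label)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


# Usage, once the Flask service is running (replace the host with your own):
# result = predict("http://xx.xx.xx.xx:8000", "珍惜当下 局部新一轮升浪悄然开启", "财经")
# print(result["Predict Label"], result["True Label"])
```

Splitting request construction from the network call lets the payload be inspected without a running server.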