[pai-diffusion]pai的easynlp的clip模型训练

2023-09-21 19:15:23

EasyNLP带你玩转CLIP图文检索 - 知乎作者:熊兮、章捷、岑鸣、临在导读随着自媒体的不断发展,多种模态数据例如图像、文本、语音、视频等不断增长,创造了互联网上丰富多彩的世界。为了准确建模用户的多模态内容,跨模态检索是跨模态理解的重要任务,…icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/528476134

initialize_easynlp()->

train_dataset = CLIPDataset(pretrained_model_name_or_path=get_pretrain_model_path("alibaba-pai/clip_chinese_roberta_base_vit_base"),
    data_file="MUGE_MR_train_base64_part.tsv",
    max_seq_length=32,
    input_schema="text:str:1,image:str:1",
    first_sequence="text",
    second_sequence="image",
    is_training=True)
valid_dataset = CLIPDataset()

model = get_application_model(app_name='clip',...)
- easynlp.appzoo.api.ModelMapping->CLIPApp
- easynlp.appzoo.clip.model.py->CLIPApp
- CHINESE_CLIP->
- self.visual = VisualTransformer()
- self.bert = BertModel()

trainer = Trainer(model,train_dataset,user_defined_parameters,  
                evaluator=get_application_evaluator(app_name="clip",valid_dataset=valid_dataset,user_defined_parameters=user_defined_parameters,eval_batch_size=32))

trainer.train()
- for _epoch in range(self._first_epoch,int(args.epoch_num)):
      for _step,batch in enumerate(self._train_loader):    
          label_ids = batch.pop()
          forward_outputs = self._model(batch)
          loss_dict = self.model_module.compute_loss(forward_outputs,label_ids)
          _loss = loss_dict('loss')
          
          _loss.backward()

model = get_application_model_evaluation()
evaluator = get_application_evaluator()
evaluator.evaluate(model)

数据处理:

import os
import base64
import multiprocessing
from tqdm import tqdm


def process_image(image_path):
    # 从图片路径中提取中文描述
    image_name = os.path.basename(image_path)
    description = os.path.splitext(image_name)[0]

    # 将图片转换为 Base64 编码
    with open(image_path, 'rb') as f:
        image_data = f.read()
        base64_data = base64.b64encode(image_data).decode('utf-8')

    return description, base64_data


def generate_tsv(directory):
    image_paths = [os.path.join(directory, filename) for filename in os.listdir(directory) if
                   filename.endswith(('.jpg', '.png'))]

    with multiprocessing.Pool() as pool, tqdm(total=len(image_paths), desc='Processing Images') as pbar:
        results = []
        for result in pool.imap_unordered(process_image, image_paths):
            results.append(result)
            pbar.update(1)

    with open(
            '/home/image_team/image_team_docker_home/lgd/e_commerce_sd/data/vcg_furnitures_text_image/vcg_furnitures_train.tsv',
            'w', encoding='utf-8') as f:
        for description, base64_data in results:
            line = f"{description}\t{base64_data}\n"
            f.write(line)


if __name__ == '__main__':
    target_directory = "/home/image_team/image_team_docker_home/lgd/e_commerce_sd/data/vcg_furnitures_text_image/vcg_furnitures_train/img_download/"
    # import pdb;pdb.set_trace()
    generate_tsv(target_directory)

训练代码:

import torch.cuda
from easynlp.appzoo import CLIPDataset
from easynlp.appzoo import get_application_predictor, get_application_model, get_application_evaluator, \
    get_application_model_for_evaluation
from easynlp.core import Trainer, PredictorManager
from easynlp.utils import initialize_easynlp, get_args, get_pretrain_model_path
from easynlp.utils.global_vars import parse_user_defined_parameters


def main():
    # /root/.easynlp/modelzoo中
    train_dataset = CLIPDataset(
        pretrained_model_name_or_path=get_pretrain_model_path(args.pretrained_model_name_or_path),
        data_file=args.tables.split(",")[0],
        max_seq_length=args.sequence_length,
        input_schema=args.input_schema,
        first_sequence=args.first_sequence,
        second_sequence=args.second_sequence,
        is_training=True)

    valid_dataset = CLIPDataset(
        # 预训练模型名称路径,这里我们使用封装好的get_pretrain_model_path函数,来处理模型名称"alibaba-pai/clip_chinese_roberta_base_vit_base"以得到其路径,并自动下载模型
        pretrained_model_name_or_path=get_pretrain_model_path(args.pretrained_model_name_or_path),
        data_file=args.tables.split(",")[-1],
        # "data/pai/MUGE_MR_valid_base64_part.tsv"
        max_seq_length=args.sequence_length,  # 文本最大长度,超过将截断,不足将padding
        input_schema=args.input_schema,  # 输入tsv数据的格式,逗号分隔的每一项对应数据文件中每行以\t分隔的一项,每项开头为其字段标识,如label、sent1等
        first_sequence=args.first_sequence,  # 用于说明input_schema中哪些字段作为第一/第二列输入数据
        second_sequence=args.second_sequence,
        is_training=False)  # 是否为训练过程,train_dataset为True,valid_dataset为False

    model = get_application_model(
        app_name=args.app_name,  # 任务名称,这里选择文本分类"clip"
        pretrained_model_name_or_path=get_pretrain_model_path(
            args.pretrained_model_name_or_path),
        user_defined_parameters=user_defined_parameters
        # user_defined_parameters:用户自定义参数,直接填入刚刚处理好的自定义参数user_defined_parameters
    )

    trainer = Trainer(model=model,
                      train_dataset=train_dataset,
                      user_defined_parameters=user_defined_parameters,
                      evaluator=get_application_evaluator(app_name=args.app_name,
                                                          valid_dataset=valid_dataset,
                                                          user_defined_parameters=user_defined_parameters,
                                                          eval_batch_size=32))
    trainer.train()

    # 模型评估
    model = get_application_model_for_evaluation(app_name=args.app_name,
                                                 pretrained_model_name_or_path=args.checkpoint_dir,
                                                 user_defined_parameters=user_defined_parameters)

    evaluator = get_application_evaluator(app_name=args.app_name,
                                          valid_dataset=valid_dataset,
                                          user_defined_parameters=user_defined_parameters,
                                          eval_batch_size=32)
    model.to(torch.cuda.current_device())
    evaluator.evaluate(model=model)

    # 模型预测
    if test:
        predictor = get_application_predictor(app_name="clip",
                                              model_dir="./outputs/clip_model/",
                                              first_sequence="text",
                                              second_sequence="image",
                                              sequence_length=32,
                                              user_defined_parameters=user_defined_parameters)

        predictor_manager = PredictorManager(predictor=predictor,
                                             input_file="data/vcg_furnitures_text_image/vcg_furnitures_test.tsv",
                                             input_schema="text:str:1",
                                             output_file="text_feat.tsv",
                                             output_schema="text_feat",
                                             append_cols="text",
                                             batch_size=2)
        predictor_manager.run()


if __name__ == "__main__":
    initialize_easynlp()
    args = get_args()
    user_defined_parameters = parse_user_defined_parameters(
        'pretrain_model_name_or_path=alibaba-pai/clip_chinese_roberta_base_vit_base')
    args.checkpoint_dir = "./outputs/clip_model/"
    args.pretrained_model_name_or_path = "alibaba-pai/clip_chinese_roberta_base_vit_base"
    # args.n_gpu = 3
    # args.worker_gpu = "1,2,3"
    args.app_name = "clip"
    args.tables = "data/pai/MUGE_MR_train_base64_part.tsv,data/pai/MUGE_MR_valid_base64_part.tsv"
    # "data/vcg_furnitures_text_image/vcg_furnitures_train.tsv," \
    #               "data/vcg_furnitures_text_image/vcg_furnitures_test.tsv"
    # "data/pai/MUGE_MR_train_base64_part.tsv,data/pai/MUGE_MR_valid_base64_part.tsv"
    args.input_schema = "text:str:1,image:str:1"
    args.first_sequence = "text"
    args.second_sequence = "image"
    args.learning_rate = 1e-4
    args.epoch_num = 1000
    args.random_seed = 42
    args.save_checkpoint_steps = 200
    args.sequence_length = 32
    # args.train_batch_size = 2
    args.micro_batch_size = 32

    test = False

    main()

# python -m torch.distributed.launch --nproc_per_node 4 tools/train_pai_chinese_clip.py


说一点自己的想法,在我自己工作之初,我很喜欢去拆解一些框架,例如openmm系列,但其实大部分在训练过程上都是相似的,大可不必,在改动上,也没有必要对其进行流程上的大改动,兼具百家之长,了解整体pipeline,更加专注在pipeline实现和效果导向型的结果提交更加有效。

更多推荐

云原生微服务治理 第五章 Spring Cloud Netflix 之 Ribbon

系列文章目录第一章Java线程池技术应用第二章CountDownLatch和Semaphone的应用第三章SpringCloud简介第四章SpringCloudNetflix之Eureka第四章SpringCloudNetflix之Ribbon文章目录系列文章目录@[TOC](文章目录)前言1、负载均衡1.1、服务端负

Python in Visual Studio Code 2023年9月更新

作者:CourtneyWebster-ProgramManager,PythonExtensioninVisualStudioCode排版:AlanWang我们很高兴地宣布VisualStudioCode的Python和Jupyter扩展将于2023年9月发布!此版本包括以下内容:•将Python的“Recreate”

uniapp----微信小程序 日历组件(周日历&& 月日历)【Vue3+ts+uView】

uniapp----微信小程序日历组件(周日历&&月日历)【Vue3+ts+uView】用Vue3+ts+uView来编写日历组件;存在周日历和月日历两种显示方式;高亮显示当天日期,红点渲染有数据的日期,点击显示数据1.calendar-week-mouth组件代码<template><viewclass="calen

虹科分享 | 谷歌Vertex AI平台使用Redis搭建大语言模型

文章来源:虹科云科技点此阅读原文基础模型和高性能数据层这两个基本组件始终是创建高效、可扩展语言模型应用的关键,利用Redis搭建大语言模型,能够实现高效可扩展的语义搜索、检索增强生成、LLM缓存机制、LLM记忆和持久化。有Redis加持的大语言模型可应用于文档检索、虚拟购物助手、客户服务助理等,为企业带来益处。一、语言

服务器的架构有哪些

服务器的架构有哪些1、单体架构软件设计经典的3层模型是表现层,业务逻辑层,数据访问层。典型的单体架构就是将所有的业务场景的表现层,业务逻辑层,数据访问层放在一个工程中最终经过编译,打包,部署在一台服务器上。2、垂直架构垂直架构是将一个大项目,按照业务场景纵向拆分为互不相干的单体架构的项目。3、前后端分离前后端分离是横向

近年来国内室内定位领域硕士论文选题的现状与趋势

目录一、前言二、选题的目的和意义三、选题现状分析四、选题趋势分析一、前言本博文采用了图表统计法分析了近5年来100余篇高被引室内定位领域硕士论文选题的现状,并从选题现状中得出了该领域选题的大致趋势。本文还通过分析该领域硕士毕业论文选题的现状和趋势,对未来该领域选题提出了自己的见解和展望。二、选题的目的和意义无论是大学生

成为威胁:网络安全中的动手威胁模拟案例

不断变化的网络威胁形势要求组织为其网络安全团队配备必要的技能来检测、响应和防御恶意攻击。然而,在研究中发现并继续探索的最令人惊讶的事情是,欺骗当前的网络安全防御是多么容易。防病毒程序建立在庞大的签名数据库之上,只需更改程序内的文本这样简单的操作就很容易崩溃。这同样适用于网络签名以及端点检测和响应。防御技术主要关注某些行

区块链安全,哈希函数暴露的攻击向量与对策

区块链安全,哈希函数暴露的攻击向量与对策简介LengthExtensionAttack是一种与某些特定类型的哈希函数(如MD5,SHA-1和SHA-2)的特性有关的攻击。简单来说,这种攻击利用了一个事实,即知道H(message)和message的长度,我们可以轻松计算出H(message||padding||exte

QTday3

#include"widget.h"Widget::Widget(QWidget*parent):QWidget(parent){this->setFixedSize(600,450);//将窗口固定大小this->setWindowIcon(QIcon(":/wodepeizhenshi.png"));//设置窗口图

驱动开发---基于gpio子系统编写LED灯的驱动

一、GPIO子系统相关API1.解析GPIO相关的设备树节点structdevice_node*of_find_node_by_path(constchar*path)功能:根据设备树节点路径解析设备树节点信息参数:path:设备树所在的节点路径/mynode@0X12345678返回值:成功返回目标节点首地址,失败返

第33节——useRef

一、概念useRef,他的作用是“勾住”某些组件挂载完成或重新渲染完成后才拥有的某些对象,并返回该对象的引用。该引用在组件整个生命周期中都固定不变,该引用并不会随着组件重新渲染而失效。返回一个可变的ref对象,该对象只有个current属性,初始值为传入的参数(initialValue)。返回的ref对象在组件的整个生

热文推荐