argsparse

参数配置库-argsparse

argsparse是python的命令行解析的标准模块,内置于python,不需要安装。这个库可以让我们直接在命令行中就可以向程序中传入参数并让程序运行。

引用及使用

声明ArgumentParser对象后,可以使用add_argument方法在parser中添加参数。type用于设置参数的数据类型,default为默认值,help为参数出错时的提示。还可以为参数设置required为True或者False,代表程序执行时必须设定该参数的值,否则会报错。其他的参数设置可以参考命令行选项、参数和子命令解析器

1
2
3
4
5
6
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--task",type=str,default="nothing",help="执行任务名称")
args = parser.parse_args(args=[])
print(args.task) # nothing

其中”–task”表示task为可选参数,如果没有”–”表示为必选参数,即required,如果运行程序时未指定参数值,即便设置了default也仍然会报错。

使用json生成config文件

训练模型过程中,通常会涉及到多模型的性能比较,不同模型的参数设置往往都是不相同的,使用上述的配置参数方式可能便捷性较差。我们可以将所有的参数分成两部分,一部分是不怎么会调整的参数,直接使用add_argument方法定义,另一部分是模型中可能变动、可能需要调整的参数,将这些参数写成一个json。这样,不同模型调用不同json中的设置,以降低使用不同模型进行试验时调整参数的复杂度。
对于类似输出路径、预训练模型路径等不怎么需要变更的参数使用add_argument方法设定,对于模型学习率等参数,我们记录在json当中。

1
2
3
4
5
6
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='./config/something.json')
parser.add_argument('--save_path', type=str, default='./outputs')
parser.add_argument('--bert_name', type=str, default=r"E:\MyPython\Pre-train-Model\mc-bert-base")
parser.add_argument('--device', type=str, default="cuda")
args = parser.parse_args(args=[])

我们已经在args中设定了config的路径,使用json读取这个路径就可以得到剩余的参数。我们可以声明一个config类,用来存json当中的每一项参数,然后将两部分参数整合在一起就是所有参数了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class Config:
def __init__(self, args):

# 使用json读取参数
with open(args.config, "r", encoding="utf-8") as f:
config = json.load(f)
# 设置需要调整的参数
self.loss_type = config["loss_type"]
self.learning_rate = config["learning_rate"]
self.bert_learning_rate = config["bert_learning_rate"]
self.weight_decay = config["weight_decay"]

# 将两部分参数整合在一起
for k, v in args.__dict__.items():
if v is not None:
self.__dict__[k] = v

def __repr__(self):
return "{}".format(self.__dict__.items())

这样我们就可以在程序中实例化config对象,通过访问对象属性的方式来获取预先设定好的参数。

1
2
config = Config(args)
print(config.learning_rate)

题外话

通常情况下,使用add_augment定义的往往都是单一的参数,实际上在add_augment方法中可以使用nargs关键字来获取多个参数。举个栗子叭:

1
2
3
4
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--files", type=str, default=[r"C:/Users/.ssh/id_rsa",r"C:/Users/.ssh/id_rsa.pub"],nargs='+')
args = parser.parse_args(args=[])

此时我们单独打印args.files可以得到:

1
2
args.files
# ['C:/Users/.ssh/id_rsa', 'C:/Users/.ssh/id_rsa.pub']

可以通过遍历得到每一个传入的路径:

1
2
3
4
for i in args.files:
print(i)
# C:/Users/.ssh/id_rsa
# C:/Users/.ssh/id_rsa.pub

huggingface的参数配置-HFArgumentParser

本节内容来源于transformer.HfArgumentParser的使用
HfArgumentParser是Transformer框架中的命令行解析工具,它是ArgumentParser的子类。用于从类对象中创建解析对象。
在python中,我们习惯于将有关联的一些参数放在一个类当中,HfArgumentParser可以将类对象中的实例属性转换成转换为解析参数。需要注意的是这里的类对象必须是通过@dataclass()创建的类对象

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from transformers import HfArgumentParser
from dataclasses import dataclass,field
from typing import Optional
@dataclass()
class BasicSetting():
# a:str = field(default="bagging")
model_path : str = field(default=r"E:\MyPython\Pre-train-Model\bert-base-chinese")
@dataclass()
class HyperParameters():
bert_learning_rate:float = field(
default=3e-5
)
parser = HfArgumentParser((BasicSetting,HyperParameters))
basic,hyper = parser.parse_args_into_dataclasses()
print(basic.model_path) # E:\MyPython\Pre-train-Model\bert-base-chinese
print(hyper.bert_learning_rate) # 3e-05

dataclass

dataclass是Python3.7 开始引入的一个新功能, dataclass提供了开箱即用的方法来创建自定义数据, 可以直接实例化、打印和比较数据类实例。
dataclass 可以认为是提供了一个简写__init__方法的语法糖。类型注释是必填项 (不限制数据类型时, 添加typing.Any为类型注释), 默认值的传递方式和__init__方法的参数格式一致.

1
2
3
4
5
6
from dataclasses import dataclass
@dataclass
class HyperParameter:
bert_learning_rate: float
learning_rate: float
print(HyperParameter(5e-5,1e-3)) # HyperParameter(bert_learning_rate=5e-05, learning_rate=0.001)

field

当我们尝试使用可变的数据类型, 给dataclass做默认值时, 可能会导致多个实例公用一个数据从而引发bug。dataclass 默认阻止使用可变数据做默认值,此时需要使用field中的default_factory。

1
2
3
4
5
6
7
8
9
10
11
from dataclasses import dataclass
from typing import List
@dataclass
class HyperParameter:
bert_learning_rate: float
learning_rate: float
hp = HyperParameter(5e-5,1e-3)
@dataclass
class HyperParameter:
paramters : List[HyperParameter] = [hp]
# mutable default <class 'list'> for field paramters is not allowed: use default_factory

此时就需要使用field来完成相应实现。

1
2
3
4
5
6
7
8
9
10
11
12
from dataclasses import dataclass,field
from typing import List
@dataclass
class HyperParameter:
bert_learning_rate: float
learning_rate: float
hp = HyperParameter(5e-5,1e-3)
@dataclass
class HyperParameter:
paramters : List[HyperParameter] = field(default_factory=lambda:[hp])
_hp = HyperParameter()
print(_hp) # HyperParameter(paramters=[HyperParameter(bert_learning_rate=5e-05, learning_rate=0.001)])