编辑
2025-04-19
深度学习
00

目录

image_max_pixels
video_max_pixels
trust_remote_code
lora_rank
lora_target
cutoff_len
max_samples
overwrite_cache
logging_steps
save_steps
overwrite_output_dir
save_only_model
report_to
resume_from_checkpoint

image_max_pixels

image_max_pixels:下面代码中的 area 指图像的像素面积,即 宽 × 高。

python
# Defined in src/llamafactory/data/mm_plugin.py.
# Controls the maximum number of pixels used during image processing;
# images above this pixel count get resized.
def get_image_processor_preprocess_params(image_processor):
    """Collect the crop/resize parameters to use with ``image_processor``.

    Args:
        image_processor: Object that may expose ``crop_size`` (a mapping with
            "height" and "width"), ``size`` (a dict holding "max_edge", or any
            other value, which is used verbatim) and ``image_max_pixels``.

    Returns:
        A dict that may contain "crop_size" and "size". It is empty when the
        processor exposes neither ``crop_size`` nor ``size``.

    Raises:
        KeyError: If ``crop_size`` lacks "height"/"width", or a dict-valued
            ``size`` lacks "max_edge".
    """
    params = {}
    if hasattr(image_processor, "crop_size"):  # for CLIP
        crop = image_processor.crop_size
        params["crop_size"] = {"height": crop["height"], "width": crop["width"]}
        params["size"] = max(crop["height"], crop["width"])
    elif hasattr(image_processor, "size"):
        size = image_processor.size
        # A dict-valued size (Qwen) stores the limit under "max_edge"; every
        # other value (e.g. the InternVL list) is passed through unchanged.
        params["size"] = size["max_edge"] if isinstance(size, dict) else size

    # image_max_pixels keeps images from getting too large for memory: when a
    # size was found above, it is replaced by the side length of a square whose
    # area is image_max_pixels (truncated to an int).
    if params.get("size") is not None and hasattr(image_processor, "image_max_pixels"):
        area = image_processor.image_max_pixels  # comes from training_args.image_max_pixels
        params["size"] = int(math.sqrt(area))
    return params

video_max_pixels

python
# 在src/llamafactory/data/mm_plugin.py中定义 # 控制视频帧处理时的最大像素数量 # 类似image_max_pixels,但应用于视频帧 def _get_mm_inputs( self, images: list["ImageInput"], videos: list["VideoInput"], audios: list["AudioInput"], processor: "ProcessorMixin", **kwargs, ) -> dict[str, "torch.Tensor"]: # ... if len(videos) != 0: # video_max_pixels用来限制视频帧的尺寸 if hasattr(processor, "video_max_pixels"): video_size = int(math.sqrt(processor.video_max_pixels)) image_kwargs["size"] = video_size # ...

trust_remote_code

python
# 在src/llamafactory/train/trainer.py中使用 # 允许从Hugging Face加载模型时执行远程代码 # 对于一些需要自定义代码的模型(如InternVL)是必需的 def create_model_and_tokenizer( args: Dict[str, Any], training_args: Optional[TrainingArguments] = None ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: # ... config_kwargs = { "trust_remote_code": args.get("trust_remote_code", False), "cache_dir": args.get("cache_dir", None), } # ... model = AutoModelForCausalLM.from_pretrained( model_name_or_path, config=config, torch_dtype=torch_dtype, trust_remote_code=args.get("trust_remote_code", False), # ... )

lora_rank

python
# 在src/llamafactory/train/peft_trainer.py中使用 # LoRA适配器的秩,决定了低秩适配矩阵的维度,影响模型参数量和表达能力 def get_peft_config( train_args: Dict[str, Any], model_args: Dict[str, Any], model: PreTrainedModel ) -> Dict[str, Any]: # ... if finetuning_type == "lora": lora_config = { "lora_alpha": train_args.get("lora_alpha", 16), "lora_dropout": train_args.get("lora_dropout", 0.05), "r": train_args.get("lora_rank", 8), # lora_rank决定了LoRA矩阵的秩 "bias": train_args.get("lora_bias", "none"), "target_modules": train_args.get("lora_target", "all"), } # ...

lora_target

python
# 在src/llamafactory/train/peft_trainer.py中使用 # 指定应用LoRA的模块名称列表或特殊值"all" # "all"表示自动检测并应用到所有线性层 def get_peft_config( train_args: Dict[str, Any], model_args: Dict[str, Any], model: PreTrainedModel ) -> Dict[str, Any]: # ... if finetuning_type == "lora": lora_config = { # ... "target_modules": train_args.get("lora_target", "all"), # 指定应用LoRA的模块 } if lora_config["target_modules"] == "all": # 对所有支持的模块应用LoRA lora_config["target_modules"] = find_all_linear_modules(model)

cutoff_len

python
# 在src/llamafactory/data/template.py和相关数据处理文件中使用 # 决定了输入序列的最大长度,超过该长度的输入会被截断 def preprocess_supervised_dataset( examples: Dict[str, List[Any]], tokenizer: PreTrainedTokenizer, template: Template ) -> Dict[str, List[List[int]]]: # ... tokenized_inputs = tokenizer( sources, padding=False, truncation=True, max_length=data_args.cutoff_len, # 使用cutoff_len限制输入长度 # ... )

max_samples

python
# 在src/llamafactory/data/dataset.py中使用 # 限制训练时使用的最大样本数量,可用于快速测试或限制训练集大小 def preprocess_dataset( dataset_path: str | list[str], tokenizer: PreTrainedTokenizer, data_args: DataArguments ) -> Dataset: # ... if data_args.max_samples is not None and data_args.max_samples > 0: # 限制样本数量 dataset = dataset.select(range(min(len(dataset), data_args.max_samples)))

overwrite_cache

python
# 在src/llamafactory/data/dataset.py中使用 # 决定是否覆盖之前缓存的预处理数据集 def load_dataset_from_path( dataset_path: str | list[str], data_args: DataArguments, splits: list[str], **kwargs ) -> Dataset: # ... streaming_kwargs = { "streaming": data_args.streaming, "cache_dir": data_args.cache_dir, "keep_in_memory": False, "download_mode": ("force_redownload" if data_args.overwrite_cache else None), } # overwrite_cache决定是否强制重新下载和处理数据

logging_steps

python
# 在src/llamafactory/train/trainer.py中使用 # 定义每隔多少步记录一次训练日志(损失等信息) def get_train_args(args: Dict[str, Any]) -> TrainingArguments: # ... return TrainingArguments( # ... logging_dir=os.path.join(args["output_dir"], "logging"), logging_strategy=args.get("logging_strategy", "steps"), logging_steps=args.get("logging_steps", 10), # 每10步记录一次日志 # ... )

save_steps

python
# 在src/llamafactory/train/trainer.py中使用 # 定义每隔多少步保存一次模型检查点 def get_train_args(args: Dict[str, Any]) -> TrainingArguments: # ... return TrainingArguments( # ... save_strategy=args.get("save_strategy", "steps"), save_steps=args.get("save_steps", 500), # 每500步保存一次检查点 # ... )

overwrite_output_dir

python
# 在src/llamafactory/train/trainer.py中使用 # 决定是否覆盖已存在的输出目录 def get_train_args(args: Dict[str, Any]) -> TrainingArguments: # ... return TrainingArguments( # ... overwrite_output_dir=args.get("overwrite_output_dir", False), # 是否覆盖输出目录 # ... )

save_only_model

python
# 在src/llamafactory/train/trainer.py中使用 # 决定是只保存模型还是保存整个训练状态(包括优化器状态等) def get_train_args(args: Dict[str, Any]) -> TrainingArguments: # ... return TrainingArguments( # ... save_only_model=args.get("save_only_model", False), # 是否只保存模型而不保存训练状态 # ... )

report_to

python
# 在src/llamafactory/train/trainer.py中使用 # 定义训练报告发送到哪些平台(如wandb, tensorboard等) def get_train_args(args: Dict[str, Any]) -> TrainingArguments: # ... return TrainingArguments( # ... report_to=args.get("report_to", None), # 指定报告平台 # ... )

resume_from_checkpoint

python
# 在src/llamafactory/train/trainer.py中使用 # 指定从哪个检查点恢复训练 def main(): # ... trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) # 如果resume_from_checkpoint不为None,将从指定检查点恢复训练
如果对你有用的话,可以打赏哦
打赏
ali pay
wechat pay

本文作者:Dong

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC 许可协议,即《知识共享署名-非商业性使用 4.0 国际许可协议》。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。转载请注明出处!