编辑
2025-05-08
深度学习
00

目录

SharegptDatasetConverter 类解析
1. 角色标签映射
2. 消息结构验证
3. 系统提示处理
4. 多种样本类型处理
5. 多模态数据处理
整体输出格式

SharegptDatasetConverter 类解析

SharegptDatasetConverter是LLaMA-Factory中的一个核心数据处理组件,专门用于处理ShareGPT格式的对话数据(包括OpenAI格式)。这个转换器将各种形式的对话数据转换为统一的内部格式,方便后续处理。下面几个方面可以帮助你理解它的工作原理:

1. 角色标签映射

代码开始定义了一个tag_mapping字典,将数据中的角色标签(例如"user"、"assistant")映射到内部使用的枚举值。例如,对于OpenAI格式,它会将"user"映射到Role.USER.value,将"assistant"映射到Role.ASSISTANT.value等。这保证了不同数据集之间角色表示的一致性。

2. 消息结构验证

代码还设置了严格的消息结构验证规则:

  • 用户和观察消息(odd_tags)必须出现在奇数位置
  • 助手和函数消息(even_tags)必须出现在偶数位置
  • 消息总数必须正确(普通对话应该是偶数条,排序对话应该是奇数条)

如果数据不符合这些规则,会被标记为"broken_data",并跳过处理。

3. 系统提示处理

代码特别处理了系统提示(system prompt)。它会检查第一条消息是否是系统提示,如果是,就将其提取出来单独存储。这对于包含系统提示的OpenAI格式尤为重要。

4. 多种样本类型处理

转换器能够处理三种不同类型的样本:

  • 普通对话样本:标准的用户-助手交互
  • KTO样本:包含人类反馈标签的样本,用于基于人类反馈的训练
  • 成对比较样本:包含选择和拒绝回答的样本,用于DPO训练

对于每种类型,转换器会相应地构建promptresponse数据结构。

5. 多模态数据处理

最后,转换器处理多模态数据(图片、视频、音频):

  • 通过_find_medias方法找到媒体文件的路径
  • 将媒体数据与文本对话关联起来
  • 正确处理各种路径格式,包括本地和远程路径

整体输出格式

转换后的数据被统一为一个包含以下字段的字典:

  • _prompt:输入提示部分
  • _response:模型响应部分
  • _system:系统提示
  • _tools:工具描述
  • _images/_videos/_audios:多模态媒体数据

这种统一的格式极大地简化了后续的训练过程,使模型能够一致地处理不同来源和格式的数据。

SharegptDatasetConverter

python
@dataclass class SharegptDatasetConverter(DatasetConverter): def __call__(self, example: dict[str, Any]) -> dict[str, Any]: tag_mapping = { self.dataset_attr.user_tag: Role.USER.value, self.dataset_attr.assistant_tag: Role.ASSISTANT.value, self.dataset_attr.observation_tag: Role.OBSERVATION.value, self.dataset_attr.function_tag: Role.FUNCTION.value, self.dataset_attr.system_tag: Role.SYSTEM.value, } odd_tags = (self.dataset_attr.user_tag, self.dataset_attr.observation_tag) even_tags = (self.dataset_attr.assistant_tag, self.dataset_attr.function_tag) accept_tags = (odd_tags, even_tags) messages = example[self.dataset_attr.messages] if ( self.dataset_attr.system_tag and len(messages) != 0 and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag ): system = messages[0][self.dataset_attr.content_tag] messages = messages[1:] else: system = example[self.dataset_attr.system] if self.dataset_attr.system else "" aligned_messages = [] broken_data = False for turn_idx, message in enumerate(messages): if message[self.dataset_attr.role_tag] not in accept_tags[turn_idx % 2]: logger.warning_rank0(f"Invalid role tag in {messages}.") broken_data = True break aligned_messages.append( { "role": tag_mapping[message[self.dataset_attr.role_tag]], "content": message[self.dataset_attr.content_tag], } ) if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or ( self.dataset_attr.ranking and len(aligned_messages) % 2 == 0 ): logger.warning_rank0(f"Invalid message count in {messages}.") broken_data = True if broken_data: logger.warning_rank0("Skipping this abnormal example.") prompt, response = [], [] elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example prompt = aligned_messages[:-1] response = aligned_messages[-1:] if example[self.dataset_attr.kto_tag]: response = response + [{"role": Role.ASSISTANT.value, "content": ""}] else: response = [{"role": Role.ASSISTANT.value, "content": ""}] + response elif ( self.dataset_attr.ranking and isinstance(example[self.dataset_attr.chosen], dict) and isinstance(example[self.dataset_attr.rejected], dict) ): # pairwise example chosen = example[self.dataset_attr.chosen] rejected = example[self.dataset_attr.rejected] if ( chosen[self.dataset_attr.role_tag] not in accept_tags[-1] or rejected[self.dataset_attr.role_tag] not in accept_tags[-1] ): logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.") broken_data = True prompt = aligned_messages response = [ { "role": tag_mapping[chosen[self.dataset_attr.role_tag]], "content": chosen[self.dataset_attr.content_tag], }, { "role": tag_mapping[rejected[self.dataset_attr.role_tag]], "content": rejected[self.dataset_attr.content_tag], }, ] else: # normal example prompt = aligned_messages[:-1] response = aligned_messages[-1:] output = { "_prompt": prompt, "_response": response, "_system": system, "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "", "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None, "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None, "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None, } return output
python
@dataclass class DatasetAttr: r"""Dataset attributes.""" # basic configs load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"] dataset_name: str formatting: Literal["alpaca", "sharegpt"] = "alpaca" ranking: bool = False # extra configs subset: Optional[str] = None split: str = "train" folder: Optional[str] = None num_samples: Optional[int] = None # common columns system: Optional[str] = None tools: Optional[str] = None images: Optional[str] = None videos: Optional[str] = None audios: Optional[str] = None # dpo columns chosen: Optional[str] = None rejected: Optional[str] = None kto_tag: Optional[str] = None # alpaca columns prompt: Optional[str] = "instruction" query: Optional[str] = "input" response: Optional[str] = "output" history: Optional[str] = None # sharegpt columns messages: Optional[str] = "conversations" # sharegpt tags role_tag: Optional[str] = "from" content_tag: Optional[str] = "value" user_tag: Optional[str] = "human" assistant_tag: Optional[str] = "gpt" observation_tag: Optional[str] = "observation" function_tag: Optional[str] = "function_call" system_tag: Optional[str] = "system" def __repr__(self) -> str: return self.dataset_name def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None: setattr(self, key, obj.get(key, default)) def join(self, attr: dict[str, Any]) -> None: self.set_attr("formatting", attr, default="alpaca") self.set_attr("ranking", attr, default=False) self.set_attr("subset", attr) self.set_attr("split", attr, default="train") self.set_attr("folder", attr) self.set_attr("num_samples", attr) if "columns" in attr: column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"] column_names += ["images", "videos", "audios", "chosen", "rejected", "kto_tag"] for column_name in column_names: self.set_attr(column_name, attr["columns"]) if "tags" in attr: tag_names = ["role_tag", "content_tag"] tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"] for tag in tag_names: self.set_attr(tag, attr["tags"])
如果对你有用的话,可以打赏哦
打赏
ali pay
wechat pay

本文作者:Dong

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!