SharegptDatasetConverter
是LLaMA-Factory中的一个核心数据处理组件,专门用于处理ShareGPT格式的对话数据(包括OpenAI格式)。这个转换器将各种形式的对话数据转换为统一的内部格式,方便后续处理。下面几个方面可以帮助你理解它的工作原理:
代码开始定义了一个tag_mapping
字典,将数据中的角色标签(例如"user"、"assistant")映射到内部使用的枚举值。例如,对于OpenAI格式,它会将"user"映射到Role.USER.value,将"assistant"映射到Role.ASSISTANT.value等。这保证了不同数据集之间角色表示的一致性。
代码还设置了严格的消息结构验证规则:
如果数据不符合这些规则,会被标记为"broken_data",并跳过处理。
代码特别处理了系统提示(system prompt)。它会检查第一条消息是否是系统提示,如果是,就将其提取出来单独存储。这对于包含系统提示的OpenAI格式尤为重要。
转换器能够处理三种不同类型的样本:
对于每种类型,转换器会相应地构建prompt
和response
数据结构。
最后,转换器处理多模态数据(图片、视频、音频):
_find_medias
方法找到媒体文件的路径转换后的数据被统一为一个包含以下字段的字典:
_prompt
:输入提示部分_response
:模型响应部分_system
:系统提示_tools
:工具描述_images
/_videos
/_audios
:多模态媒体数据这种统一的格式极大地简化了后续的训练过程,使模型能够一致地处理不同来源和格式的数据。
SharegptDatasetConverter
python
@dataclass
class SharegptDatasetConverter(DatasetConverter):
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
tag_mapping = {
self.dataset_attr.user_tag: Role.USER.value,
self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
self.dataset_attr.observation_tag: Role.OBSERVATION.value,
self.dataset_attr.function_tag: Role.FUNCTION.value,
self.dataset_attr.system_tag: Role.SYSTEM.value,
}
odd_tags = (self.dataset_attr.user_tag, self.dataset_attr.observation_tag)
even_tags = (self.dataset_attr.assistant_tag, self.dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
messages = example[self.dataset_attr.messages]
if (
self.dataset_attr.system_tag
and len(messages) != 0
and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag
):
system = messages[0][self.dataset_attr.content_tag]
messages = messages[1:]
else:
system = example[self.dataset_attr.system] if self.dataset_attr.system else ""
aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages):
if message[self.dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning_rank0(f"Invalid role tag in {messages}.")
broken_data = True
break
aligned_messages.append(
{
"role": tag_mapping[message[self.dataset_attr.role_tag]],
"content": message[self.dataset_attr.content_tag],
}
)
if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning_rank0(f"Invalid message count in {messages}.")
broken_data = True
if broken_data:
logger.warning_rank0("Skipping this abnormal example.")
prompt, response = [], []
elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if example[self.dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
self.dataset_attr.ranking
and isinstance(example[self.dataset_attr.chosen], dict)
and isinstance(example[self.dataset_attr.rejected], dict)
): # pairwise example
chosen = example[self.dataset_attr.chosen]
rejected = example[self.dataset_attr.rejected]
if (
chosen[self.dataset_attr.role_tag] not in accept_tags[-1]
or rejected[self.dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
broken_data = True
prompt = aligned_messages
response = [
{
"role": tag_mapping[chosen[self.dataset_attr.role_tag]],
"content": chosen[self.dataset_attr.content_tag],
},
{
"role": tag_mapping[rejected[self.dataset_attr.role_tag]],
"content": rejected[self.dataset_attr.content_tag],
},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
output = {
"_prompt": prompt,
"_response": response,
"_system": system,
"_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
"_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
"_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
"_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
}
return output
python
@dataclass
class DatasetAttr:
r"""Dataset attributes."""
# basic configs
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
dataset_name: str
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
ranking: bool = False
# extra configs
subset: Optional[str] = None
split: str = "train"
folder: Optional[str] = None
num_samples: Optional[int] = None
# common columns
system: Optional[str] = None
tools: Optional[str] = None
images: Optional[str] = None
videos: Optional[str] = None
audios: Optional[str] = None
# dpo columns
chosen: Optional[str] = None
rejected: Optional[str] = None
kto_tag: Optional[str] = None
# alpaca columns
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
# sharegpt columns
messages: Optional[str] = "conversations"
# sharegpt tags
role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human"
assistant_tag: Optional[str] = "gpt"
observation_tag: Optional[str] = "observation"
function_tag: Optional[str] = "function_call"
system_tag: Optional[str] = "system"
def __repr__(self) -> str:
return self.dataset_name
def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
setattr(self, key, obj.get(key, default))
def join(self, attr: dict[str, Any]) -> None:
self.set_attr("formatting", attr, default="alpaca")
self.set_attr("ranking", attr, default=False)
self.set_attr("subset", attr)
self.set_attr("split", attr, default="train")
self.set_attr("folder", attr)
self.set_attr("num_samples", attr)
if "columns" in attr:
column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"]
column_names += ["images", "videos", "audios", "chosen", "rejected", "kto_tag"]
for column_name in column_names:
self.set_attr(column_name, attr["columns"])
if "tags" in attr:
tag_names = ["role_tag", "content_tag"]
tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]
for tag in tag_names:
self.set_attr(tag, attr["tags"])
本文作者:Dong
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!