diff --git a/examples/generate/generate_fill_in_blank_qa/README.md b/examples/generate/generate_fill_in_blank_qa/README.md
new file mode 100644
index 00000000..47ae6bd3
--- /dev/null
+++ b/examples/generate/generate_fill_in_blank_qa/README.md
@@ -0,0 +1,3 @@
+# Generate Fill-in-blank QAs
+
+Fill-in-blank question answering (QA) involves creating questions in which a key piece of information is omitted, requiring the respondent to supply the missing word or phrase. This format is commonly used in educational assessments to test knowledge and comprehension.
diff --git a/examples/generate/generate_fill_in_blank_qa/fill_in_blank_config.yaml b/examples/generate/generate_fill_in_blank_qa/fill_in_blank_config.yaml
new file mode 100644
index 00000000..cbd534a5
--- /dev/null
+++ b/examples/generate/generate_fill_in_blank_qa/fill_in_blank_config.yaml
@@ -0,0 +1,80 @@
+global_params:
+  working_dir: cache
+  graph_backend: kuzu # graph database backend; supported: kuzu, networkx
+  kv_backend: rocksdb # key-value store backend; supported: rocksdb, json_kv
+
+nodes:
+  - id: read_files # each id is unique within the pipeline and can be referenced by other steps
+    op_name: read
+    type: source
+    dependencies: []
+    params:
+      input_path:
+        - examples/input_examples/jsonl_demo.jsonl # input file path; supported formats: json, jsonl, txt, pdf. See examples/input_examples for samples
+
+  - id: chunk_documents
+    op_name: chunk
+    type: map_batch
+    dependencies:
+      - read_files
+    execution_params:
+      replicas: 4
+    params:
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - id: build_kg
+    op_name: build_kg
+    type: map_batch
+    dependencies:
+      - chunk_documents
+    execution_params:
+      replicas: 1
+      batch_size: 128
+
+  - id: quiz
+    op_name: quiz
+    type: map_batch
+    dependencies:
+      - build_kg
+    execution_params:
+      replicas: 1
+      batch_size: 128
+    params:
+      quiz_samples: 2 # number of quiz samples to generate
+
+  - id: judge
+    op_name: judge
+    type: map_batch
+    dependencies:
+      - quiz
+    execution_params:
+      replicas: 1
+      batch_size: 128
+
+  - id: partition
+    op_name: partition
+    type: aggregate
+    dependencies:
+      - judge
+    params:
+      method: ece # a custom partition method based on comprehension loss
+      method_params:
+        max_units_per_community: 20 # max nodes and edges per community
+        min_units_per_community: 5 # min nodes and edges per community
+        max_tokens_per_community: 10240 # max tokens per community
+        unit_sampling: max_loss # unit sampling strategy; supported: random, max_loss, min_loss
+
+  - id: generate
+    op_name: generate
+    type: map_batch
+    dependencies:
+      - partition
+    execution_params:
+      replicas: 1
+      batch_size: 128
+      save_output: true # save output to disk
+    params:
+      method: fill_in_blank
+      num_of_questions: 5
+      data_format: Alpaca # supported: Alpaca, Sharegpt, ChatML
diff --git a/examples/generate/generate_fill_in_blank_qa/generate_fill_in_blank.sh b/examples/generate/generate_fill_in_blank_qa/generate_fill_in_blank.sh
new file mode 100644
index 00000000..8911410f
--- /dev/null
+++ b/examples/generate/generate_fill_in_blank_qa/generate_fill_in_blank.sh
@@ -0,0 +1,2 @@
+python3 -m graphgen.run \
+--config_file examples/generate/generate_fill_in_blank_qa/fill_in_blank_config.yaml
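For orientation, with `data_format: Alpaca` the `generate` step above emits records shaped as in the sketch below. The values are illustrative only, borrowed from the prompt-template examples later in this diff, not real pipeline output:

```python
# Illustrative only: one fill-in-blank QA pair serialized to Alpaca format,
# as emitted by format_generation_results (reworked in this diff).
record = {
    "instruction": "The iPad Air 2 was released on ________.",
    "input": "",
    "output": "October 16, 2014",
}
```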
diff --git a/examples/generate/generate_multi_answer_qa/README.md b/examples/generate/generate_multi_answer_qa/README.md
new file mode 100644
index 00000000..bc975ef5
--- /dev/null
+++ b/examples/generate/generate_multi_answer_qa/README.md
@@ -0,0 +1,3 @@
+# Generate Multi-Answer QAs
+
+Multi-answer question answering (QA) involves generating questions that can have multiple valid answers. This is particularly useful in educational settings, surveys, and research where diverse perspectives are valuable.
\ No newline at end of file
diff --git a/examples/generate/generate_multi_answer_qa/generate_multi_answer.sh b/examples/generate/generate_multi_answer_qa/generate_multi_answer.sh
new file mode 100644
index 00000000..71b305c2
--- /dev/null
+++ b/examples/generate/generate_multi_answer_qa/generate_multi_answer.sh
@@ -0,0 +1,2 @@
+python3 -m graphgen.run \
+--config_file examples/generate/generate_multi_answer_qa/multi_answer_config.yaml
diff --git a/examples/generate/generate_multi_answer_qa/multi_answer_config.yaml b/examples/generate/generate_multi_answer_qa/multi_answer_config.yaml
new file mode 100644
index 00000000..bf6a1eaf
--- /dev/null
+++ b/examples/generate/generate_multi_answer_qa/multi_answer_config.yaml
@@ -0,0 +1,80 @@
+global_params:
+  working_dir: cache
+  graph_backend: kuzu # graph database backend; supported: kuzu, networkx
+  kv_backend: rocksdb # key-value store backend; supported: rocksdb, json_kv
+
+nodes:
+  - id: read_files # each id is unique within the pipeline and can be referenced by other steps
+    op_name: read
+    type: source
+    dependencies: []
+    params:
+      input_path:
+        - examples/input_examples/jsonl_demo.jsonl # input file path; supported formats: json, jsonl, txt, pdf. See examples/input_examples for samples
+
+  - id: chunk_documents
+    op_name: chunk
+    type: map_batch
+    dependencies:
+      - read_files
+    execution_params:
+      replicas: 4
+    params:
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - id: build_kg
+    op_name: build_kg
+    type: map_batch
+    dependencies:
+      - chunk_documents
+    execution_params:
+      replicas: 1
+      batch_size: 128
+
+  - id: quiz
+    op_name: quiz
+    type: map_batch
+    dependencies:
+      - build_kg
+    execution_params:
+      replicas: 1
+      batch_size: 128
+    params:
+      quiz_samples: 2 # number of quiz samples to generate
+
+  - id: judge
+    op_name: judge
+    type: map_batch
+    dependencies:
+      - quiz
+    execution_params:
+      replicas: 1
+      batch_size: 128
+
+  - id: partition
+    op_name: partition
+    type: aggregate
+    dependencies:
+      - judge
+    params:
+      method: ece # a custom partition method based on comprehension loss
+      method_params:
+        max_units_per_community: 20 # max nodes and edges per community
+        min_units_per_community: 5 # min nodes and edges per community
+        max_tokens_per_community: 10240 # max tokens per community
+        unit_sampling: max_loss # unit sampling strategy; supported: random, max_loss, min_loss
+
+  - id: generate
+    op_name: generate
+    type: map_batch
+    dependencies:
+      - partition
+    execution_params:
+      replicas: 1
+      batch_size: 128
+      save_output: true # save output to disk
+    params:
+      method: multi_answer
+      num_of_questions: 5
+      data_format: Alpaca # supported: Alpaca, Sharegpt, ChatML
diff --git a/examples/generate/generate_multi_choice_qa/README.md b/examples/generate/generate_multi_choice_qa/README.md
new file mode 100644
index 00000000..aed4fa29
--- /dev/null
+++ b/examples/generate/generate_multi_choice_qa/README.md
@@ -0,0 +1,3 @@
+# Generate Multi-Choice QAs
+
+Multi-choice question answering (QA) tasks involve providing a question along with several answer options, where the goal is to select the correct answer from the given choices.
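Multi-choice and multi-answer pairs additionally carry an `options` dict; `format_generation_results` (reworked later in this diff) folds the options into the question text before serialization. A sketch of the resulting Sharegpt record, with made-up values:

```python
# Sketch with made-up values: a multi-choice pair after option folding.
record = {
    "conversations": [
        {
            "from": "human",
            "value": "Which processor does iPad Air 2 use?"
            "\nOptions:\nA. A8\nB. A9X\nC. A8X\nD. A10",
        },
        {"from": "gpt", "value": "C"},
    ]
}
```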
diff --git a/examples/generate/generate_multi_choice_qa/generate_multi_choice.sh b/examples/generate/generate_multi_choice_qa/generate_multi_choice.sh
new file mode 100644
index 00000000..bd3b9eff
--- /dev/null
+++ b/examples/generate/generate_multi_choice_qa/generate_multi_choice.sh
@@ -0,0 +1,2 @@
+python3 -m graphgen.run \
+--config_file examples/generate/generate_multi_choice_qa/multi_choice_config.yaml
diff --git a/examples/generate/generate_multi_choice_qa/multi_choice_config.yaml b/examples/generate/generate_multi_choice_qa/multi_choice_config.yaml
new file mode 100644
index 00000000..91b88174
--- /dev/null
+++ b/examples/generate/generate_multi_choice_qa/multi_choice_config.yaml
@@ -0,0 +1,80 @@
+global_params:
+  working_dir: cache
+  graph_backend: kuzu # graph database backend; supported: kuzu, networkx
+  kv_backend: rocksdb # key-value store backend; supported: rocksdb, json_kv
+
+nodes:
+  - id: read_files # each id is unique within the pipeline and can be referenced by other steps
+    op_name: read
+    type: source
+    dependencies: []
+    params:
+      input_path:
+        - examples/input_examples/jsonl_demo.jsonl # input file path; supported formats: json, jsonl, txt, pdf. See examples/input_examples for samples
+
+  - id: chunk_documents
+    op_name: chunk
+    type: map_batch
+    dependencies:
+      - read_files
+    execution_params:
+      replicas: 4
+    params:
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - id: build_kg
+    op_name: build_kg
+    type: map_batch
+    dependencies:
+      - chunk_documents
+    execution_params:
+      replicas: 1
+      batch_size: 128
+
+  - id: quiz
+    op_name: quiz
+    type: map_batch
+    dependencies:
+      - build_kg
+    execution_params:
+      replicas: 1
+      batch_size: 128
+    params:
+      quiz_samples: 2 # number of quiz samples to generate
+
+  - id: judge
+    op_name: judge
+    type: map_batch
+    dependencies:
+      - quiz
+    execution_params:
+      replicas: 1
+      batch_size: 128
+
+  - id: partition
+    op_name: partition
+    type: aggregate
+    dependencies:
+      - judge
+    params:
+      method: ece # a custom partition method based on comprehension loss
+      method_params:
+        max_units_per_community: 20 # max nodes and edges per community
+        min_units_per_community: 5 # min nodes and edges per community
+        max_tokens_per_community: 10240 # max tokens per community
+        unit_sampling: max_loss # unit sampling strategy; supported: random, max_loss, min_loss
+
+  - id: generate
+    op_name: generate
+    type: map_batch
+    dependencies:
+      - partition
+    execution_params:
+      replicas: 1
+      batch_size: 128
+      save_output: true # save output to disk
+    params:
+      method: multi_choice
+      num_of_questions: 5
+      data_format: Alpaca # supported: Alpaca, Sharegpt, ChatML
"content": v["answer"]}, - ] - } - for item in results - for k, v in item.items() - ] - else: - raise ValueError(f"Unknown output data format: {output_data_format}") - return results + + flat_results = [] + for item in results: + for _, qa_data in item.items(): + question = qa_data.get("question", "") + answer = qa_data.get("answer", "") + if "options" in qa_data and qa_data["options"]: + options = qa_data["options"] + options_str = "\n".join( + [f"{key}. {options[key]}" for key in sorted(options.keys())] + ) + question += f"\nOptions:\n{options_str}" + + if output_data_format == "Alpaca": + flat_results.append( + { + "instruction": question, + "input": "", + "output": answer, + } + ) + elif output_data_format == "Sharegpt": + flat_results.append( + { + "conversations": [ + {"from": "human", "value": question}, + {"from": "gpt", "value": answer}, + ] + } + ) + elif output_data_format == "ChatML": + flat_results.append( + { + "messages": [ + {"role": "user", "content": question}, + {"role": "assistant", "content": answer}, + ] + } + ) + else: + raise ValueError( + f"Unknown output data format: {output_data_format}" + ) + return flat_results diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 43d38bed..1f083088 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -11,6 +11,9 @@ AggregatedGenerator, AtomicGenerator, CoTGenerator, + FillInBlankGenerator, + MultiAnswerGenerator, + MultiChoiceGenerator, MultiHopGenerator, QuizGenerator, VQAGenerator, diff --git a/graphgen/models/generator/__init__.py b/graphgen/models/generator/__init__.py index 49f8979c..6ccd077c 100644 --- a/graphgen/models/generator/__init__.py +++ b/graphgen/models/generator/__init__.py @@ -1,6 +1,9 @@ from .aggregated_generator import AggregatedGenerator from .atomic_generator import AtomicGenerator from .cot_generator import CoTGenerator +from .fill_in_blank_generator import FillInBlankGenerator +from .multi_answer_generator import MultiAnswerGenerator +from .multi_choice_generator import MultiChoiceGenerator from .multi_hop_generator import MultiHopGenerator from .quiz_generator import QuizGenerator from .vqa_generator import VQAGenerator diff --git a/graphgen/models/generator/fill_in_blank_generator.py b/graphgen/models/generator/fill_in_blank_generator.py new file mode 100644 index 00000000..c2f43898 --- /dev/null +++ b/graphgen/models/generator/fill_in_blank_generator.py @@ -0,0 +1,99 @@ +import re +from typing import Any + +from graphgen.bases import BaseGenerator +from graphgen.templates import FILL_IN_BLANK_GENERATION_PROMPT +from graphgen.utils import compute_content_hash, detect_main_language, logger + + +class FillInBlankGenerator(BaseGenerator): + def __init__(self, llm_client, num_of_questions) -> None: + super().__init__(llm_client) + self.num_of_questions = num_of_questions + + @staticmethod + def parse_response(response: str) -> Any: + """ + Parse fill-in-the-blank QA pairs from the LLM response. + Each QA pair contains question text with placeholders and the correct answer(s). 
diff --git a/graphgen/models/generator/fill_in_blank_generator.py b/graphgen/models/generator/fill_in_blank_generator.py
new file mode 100644
index 00000000..c2f43898
--- /dev/null
+++ b/graphgen/models/generator/fill_in_blank_generator.py
@@ -0,0 +1,99 @@
+import re
+from typing import Any
+
+from graphgen.bases import BaseGenerator
+from graphgen.templates import FILL_IN_BLANK_GENERATION_PROMPT
+from graphgen.utils import compute_content_hash, detect_main_language, logger
+
+
+class FillInBlankGenerator(BaseGenerator):
+    def __init__(self, llm_client, num_of_questions) -> None:
+        super().__init__(llm_client)
+        self.num_of_questions = num_of_questions
+
+    @staticmethod
+    def parse_response(response: str) -> Any:
+        """
+        Parse fill-in-the-blank QA pairs from the LLM response.
+        Each QA pair contains question text with placeholders and the correct answer(s).
+
+        :param response: The LLM response containing XML-formatted QA pairs
+        :return: Dictionary mapping question hash to question data, where each
+            value is a dict with "question", "answer", and "answers" keys
+        """
+        qa_pairs = {}
+
+        # Extract all QA pair blocks
+        qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
+
+        if not qa_blocks:
+            logger.warning("No QA pairs found in response: %s", response)
+            return {}
+
+        for block in qa_blocks:
+            # Extract and clean question text
+            q_match = re.search(r"<question>(.*?)</question>", block, re.DOTALL)
+            if not q_match:
+                logger.warning("Failed to parse question from block: %s", block)
+                continue
+            question = q_match.group(1).strip().strip('"').strip("'")
+
+            # Extract and clean answer text
+            ans_match = re.search(r"<answer>(.*?)</answer>", block, re.DOTALL)
+            if not ans_match:
+                logger.warning("Failed to parse answer from block: %s", block)
+                continue
+
+            answer_text = ans_match.group(1).strip().strip('"').strip("'")
+
+            # Parse multiple answers (e.g., "A8X, 八百万" or "A8X")
+            # Split by comma and strip whitespace from each answer
+            answers = [ans.strip() for ans in answer_text.split(",") if ans.strip()]
+
+            # Ensure at least one valid answer
+            if len(answers) == 0:
+                logger.warning("No valid answers found in: %s", answer_text)
+                continue
+
+            # Build result entry with question hash as key
+            question_hash = compute_content_hash(question)
+            qa_pairs[question_hash] = {
+                "question": question,
+                "answer": answer_text,  # Original answer text with commas
+                "answers": answers,  # List of individual answers: ["A8X"] or ["A8X", "八百万"]
+            }
+
+            logger.debug(
+                "Successfully parsed fill-in-the-blank question: %s", question[:50]
+            )
+
+        if not qa_pairs:
+            logger.error("Failed to parse any valid QA pairs from response")
+
+        return qa_pairs
+
+    # pylint: disable=W0221
+    def build_prompt(
+        self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        nodes, edges = batch
+        entities_str = "\n".join(
+            [
+                f"{index + 1}. {node[0]}: {node[1]['description']}"
+                for index, node in enumerate(nodes)
+            ]
+        )
+
+        relationships_str = "\n".join(
+            [
+                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
+                for index, edge in enumerate(edges)
+            ]
+        )
+        context = entities_str + "\n" + relationships_str
+        language = detect_main_language(entities_str + relationships_str)
+        prompt = FILL_IN_BLANK_GENERATION_PROMPT[language].format(
+            context=context,
+            num_of_questions=self.num_of_questions,
+        )
+        return prompt
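A round-trip sketch for the parser above, using the tag format that the fill-in-blank template (later in this diff) instructs the model to emit:

```python
# Round-trip sketch: parse a response written in the template's tag format.
from graphgen.models import FillInBlankGenerator

response = """<qa_pair>
<question>The iPad Air 2 was manufactured by ________.</question>
<answer>Apple Inc.</answer>
</qa_pair>"""

qa = next(iter(FillInBlankGenerator.parse_response(response).values()))
assert qa["answers"] == ["Apple Inc."]  # single-blank answers become a one-item list
```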
diff --git a/graphgen/models/generator/multi_answer_generator.py b/graphgen/models/generator/multi_answer_generator.py
new file mode 100644
index 00000000..b5a0db5c
--- /dev/null
+++ b/graphgen/models/generator/multi_answer_generator.py
@@ -0,0 +1,118 @@
+import re
+from typing import Any
+
+from graphgen.bases import BaseGenerator
+from graphgen.templates import MAQ_GENERATION_PROMPT
+from graphgen.utils import compute_content_hash, detect_main_language, logger
+
+
+class MultiAnswerGenerator(BaseGenerator):
+    def __init__(self, llm_client, num_of_questions) -> None:
+        super().__init__(llm_client)
+        self.num_of_questions = num_of_questions
+
+    @staticmethod
+    def parse_response(response: str) -> Any:
+        """
+        Parse multiple-answer QA pairs from the LLM response.
+        Each QA pair contains question text, four options, and the correct answers (one or more).
+
+        :param response: The LLM response containing XML-formatted QA pairs
+        :return: Dictionary mapping question hash to question data, where each
+            value is a dict with "question", "options", and "answer" keys
+        """
+        qa_pairs = {}
+
+        # Extract all QA pair blocks
+        qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
+
+        if not qa_blocks:
+            logger.warning("No QA pairs found in response: %s", response)
+            return {}
+
+        for block in qa_blocks:
+            # Extract and clean question text
+            q_match = re.search(r"<question>(.*?)</question>", block, re.DOTALL)
+            if not q_match:
+                logger.warning("Failed to parse question from block: %s", block)
+                continue
+            question = q_match.group(1).strip().strip('"').strip("'")
+
+            # Extract and parse options (A, B, C, D)
+            opt_match = re.search(r"<options>(.*?)</options>", block, re.DOTALL)
+            if not opt_match:
+                logger.warning("Failed to parse options from block: %s", block)
+                continue
+
+            options = {}
+            options_text = opt_match.group(1).strip()
+            for line in options_text.split("\n"):
+                line = line.strip()
+                if not line:
+                    continue
+                # Match patterns like "A. text" or "B. text"
+                if m := re.match(r"^([A-Z])[.\s]\s*(.*)$", line):
+                    letter, text = m.groups()
+                    options[letter] = text.strip()
+
+            # Extract and validate answer
+            ans_match = re.search(r"<answer>(.*?)</answer>", block, re.DOTALL)
+            if not ans_match:
+                logger.warning("Failed to parse answer from block: %s", block)
+                continue
+            answer_text = ans_match.group(1).strip().strip('"').strip("'")
+            answers = [
+                ans.strip().upper() for ans in answer_text.split(",") if ans.strip()
+            ]
+            invalid_answers = [ans for ans in answers if ans not in options]
+            if invalid_answers:
+                logger.warning(
+                    "Answers %s not found in options: %s",
+                    invalid_answers,
+                    list(options.keys()),
+                )
+                continue
+
+            # Ensure at least one valid answer
+            if len(answers) == 0:
+                logger.warning("No valid answers found in: %s", answer_text)
+                continue
+
+            # Build result entry with question hash as key
+            question_hash = compute_content_hash(question)
+            qa_pairs[question_hash] = {
+                "question": question,
+                "options": options,  # Dict like {"A": "text", "B": "text", ...}
+                "answer": ", ".join(answers),
+            }
+
+            logger.debug("Successfully parsed MAQ: %s", question[:50])
+
+        if not qa_pairs:
+            logger.error("Failed to parse any valid MAQ pairs from response")
+
+        return qa_pairs
+
+    # pylint: disable=W0221
+    def build_prompt(
+        self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        nodes, edges = batch
+        entities_str = "\n".join(
+            [
+                f"{index + 1}. {node[0]}: {node[1]['description']}"
+                for index, node in enumerate(nodes)
+            ]
+        )
+
+        relationships_str = "\n".join(
+            [
+                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
+                for index, edge in enumerate(edges)
+            ]
+        )
+        context = entities_str + "\n" + relationships_str
+        language = detect_main_language(entities_str + relationships_str)
+        prompt = MAQ_GENERATION_PROMPT[language].format(
+            context=context,
+            num_of_questions=self.num_of_questions,
+        )
+        return prompt
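A sketch of the multi-answer parser's normalization: answer letters are split on commas, upper-cased, and validated against the parsed options, so any unknown letter drops the whole pair:

```python
# Sketch: lower-case answer letters are normalized before validation.
from graphgen.models import MultiAnswerGenerator

response = """<qa_pair>
<question>Which of the following are features of the iPad Air 2?</question>
<options>A. Touch ID fingerprint recognition
B. A8X processor
C. Ten-megapixel front camera
D. Eight-megapixel rear camera</options>
<answer>a, b, d</answer>
</qa_pair>"""

qa = next(iter(MultiAnswerGenerator.parse_response(response).values()))
assert qa["answer"] == "A, B, D"  # normalized from "a, b, d"
```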
diff --git a/graphgen/models/generator/multi_choice_generator.py b/graphgen/models/generator/multi_choice_generator.py
new file mode 100644
index 00000000..fcac2e1b
--- /dev/null
+++ b/graphgen/models/generator/multi_choice_generator.py
@@ -0,0 +1,118 @@
+import re
+from typing import Any
+
+from graphgen.bases import BaseGenerator
+from graphgen.templates import MCQ_GENERATION_PROMPT
+from graphgen.utils import compute_content_hash, detect_main_language, logger
+
+
+class MultiChoiceGenerator(BaseGenerator):
+    def __init__(self, llm_client, num_of_questions) -> None:
+        super().__init__(llm_client)
+        self.num_of_questions = num_of_questions
+
+    @staticmethod
+    def parse_response(response: str) -> Any:
+        """
+        Parse multiple choice QA pairs from the LLM response.
+        Each QA pair contains question text, four options, and the correct answer.
+
+        :param response: The LLM response containing XML-formatted QA pairs
+        :return: Dictionary mapping question hash to question data, where each
+            value is a dict with "question", "options", and "answer" keys
+        """
+        qa_pairs = {}
+
+        # Extract all QA pair blocks
+        qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
+
+        if not qa_blocks:
+            logger.warning("No QA pairs found in response: %s", response)
+            return {}
+
+        for block in qa_blocks:
+            # Extract and clean question text
+            q_match = re.search(r"<question>(.*?)</question>", block, re.DOTALL)
+            if not q_match:
+                logger.warning("Failed to parse question from block: %s", block)
+                continue
+            question = q_match.group(1).strip().strip('"').strip("'")
+
+            # Extract and parse options (A, B, C, D)
+            opt_match = re.search(r"<options>(.*?)</options>", block, re.DOTALL)
+            if not opt_match:
+                logger.warning("Failed to parse options from block: %s", block)
+                continue
+
+            options = {}
+            options_text = opt_match.group(1).strip()
+            for line in options_text.split("\n"):
+                line = line.strip()
+                if not line:
+                    continue
+                # Match patterns like "A. text" or "B. text"
+                if m := re.match(r"^([A-D])[.\s]\s*(.*)$", line):
+                    letter, text = m.groups()
+                    options[letter] = text.strip()
+
+            # Validate options count
+            if len(options) != 4:
+                logger.warning(
+                    "Expected 4 options, found %d: %s", len(options), options_text
+                )
+                continue
+
+            # Extract and validate answer
+            ans_match = re.search(r"<answer>(.*?)</answer>", block, re.DOTALL)
+            if not ans_match:
+                logger.warning("Failed to parse answer from block: %s", block)
+                continue
+            answer = ans_match.group(1).strip().strip('"').strip("'")
+
+            # Ensure answer exists in options
+            if answer not in options:
+                logger.warning(
+                    "Answer '%s' not found in options: %s", answer, list(options.keys())
+                )
+                continue
+
+            # Build result entry with question hash as key
+            question_hash = compute_content_hash(question)
+            qa_pairs[question_hash] = {
+                "question": question,
+                "options": options,  # Dict like {"A": "text", "B": "text", ...}
+                "answer": answer,  # Single letter: "A", "B", "C", or "D"
+            }
+
+            logger.debug("Successfully parsed MCQ: %s", question[:50])
+
+        if not qa_pairs:
+            logger.error("Failed to parse any valid MCQ pairs from response")
+
+        return qa_pairs
+
+    # pylint: disable=W0221
+    def build_prompt(
+        self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        nodes, edges = batch
+        entities_str = "\n".join(
+            [
+                f"{index + 1}. {node[0]}: {node[1]['description']}"
+                for index, node in enumerate(nodes)
+            ]
+        )
+
+        relationships_str = "\n".join(
+            [
+                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
+                for index, edge in enumerate(edges)
+            ]
+        )
+        context = entities_str + "\n" + relationships_str
+        language = detect_main_language(entities_str + relationships_str)
+        prompt = MCQ_GENERATION_PROMPT[language].format(
+            context=context,
+            num_of_questions=self.num_of_questions,
+        )
+        return prompt
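The single-choice parser additionally enforces exactly four options and a single answer letter contained in them — a minimal sketch:

```python
# Sketch: a well-formed MCQ block passes both validation checks.
from graphgen.models import MultiChoiceGenerator

response = """<qa_pair>
<question>Which processor does the iPad Air 2 use?</question>
<options>A. A8
B. A9X
C. A8X
D. A10</options>
<answer>C</answer>
</qa_pair>"""

qa = next(iter(MultiChoiceGenerator.parse_response(response).values()))
assert qa["answer"] == "C" and set(qa["options"]) == {"A", "B", "C", "D"}
```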
text" + if m := re.match(r"^([A-D])[.\s]\s*(.*)$", line): + letter, text = m.groups() + options[letter] = text.strip() + + # Validate options count + if len(options) != 4: + logger.warning( + "Expected 4 options, found %d: %s", len(options), options_text + ) + continue + + # Extract and validate answer + ans_match = re.search(r"(.*?)", block, re.DOTALL) + if not ans_match: + logger.warning("Failed to parse answer from block: %s", block) + continue + answer = ans_match.group(1).strip().strip('"').strip("'") + + # Ensure answer exists in options + if answer not in options: + logger.warning( + "Answer '%s' not found in options: %s", answer, list(options.keys()) + ) + continue + + # Build result entry with question hash as key + question_hash = compute_content_hash(question) + qa_pairs[question_hash] = { + "question": question, + "options": options, # Dict like {"A": "text", "B": "text", ...} + "answer": answer, # Single letter: "A", "B", "C", or "D" + } + + logger.debug("Successfully parsed MCQ: %s", question[:50]) + + if not qa_pairs: + logger.error("Failed to parse any valid MCQ pairs from response") + + return qa_pairs + + # pylint: disable=W0221 + def build_prompt( + self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] + ) -> str: + nodes, edges = batch + entities_str = "\n".join( + [ + f"{index + 1}. {node[0]}: {node[1]['description']}" + for index, node in enumerate(nodes) + ] + ) + + relationships_str = "\n".join( + [ + f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}" + for index, edge in enumerate(edges) + ] + ) + context = entities_str + "\n" + relationships_str + language = detect_main_language(entities_str + relationships_str) + prompt = MCQ_GENERATION_PROMPT[language].format( + context=context, + num_of_questions=self.num_of_questions, + ) + return prompt diff --git a/graphgen/operators/generate/generate_service.py b/graphgen/operators/generate/generate_service.py index db784d08..ec7f0c2f 100644 --- a/graphgen/operators/generate/generate_service.py +++ b/graphgen/operators/generate/generate_service.py @@ -2,13 +2,6 @@ from graphgen.bases import BaseLLMWrapper, BaseOperator from graphgen.common import init_llm -from graphgen.models import ( - AggregatedGenerator, - AtomicGenerator, - CoTGenerator, - MultiHopGenerator, - VQAGenerator, -) from graphgen.utils import logger, run_concurrent @@ -22,6 +15,7 @@ def __init__( working_dir: str = "cache", method: str = "aggregated", data_format: str = "ChatML", + **generate_kwargs, ): super().__init__(working_dir=working_dir, op_name="generate_service") self.llm_client: BaseLLMWrapper = init_llm("synthesizer") @@ -30,15 +24,46 @@ def __init__( self.data_format = data_format if self.method == "atomic": + from graphgen.models import AtomicGenerator + self.generator = AtomicGenerator(self.llm_client) elif self.method == "aggregated": + from graphgen.models import AggregatedGenerator + self.generator = AggregatedGenerator(self.llm_client) elif self.method == "multi_hop": + from graphgen.models import MultiHopGenerator + self.generator = MultiHopGenerator(self.llm_client) elif self.method == "cot": + from graphgen.models import CoTGenerator + self.generator = CoTGenerator(self.llm_client) - elif self.method in ["vqa"]: + elif self.method == "vqa": + from graphgen.models import VQAGenerator + self.generator = VQAGenerator(self.llm_client) + elif self.method == "multi_choice": + from graphgen.models import MultiChoiceGenerator + + self.generator = MultiChoiceGenerator( + self.llm_client, + 
diff --git a/graphgen/templates/__init__.py b/graphgen/templates/__init__.py
index cbfa4e17..a9e8cd85 100644
--- a/graphgen/templates/__init__.py
+++ b/graphgen/templates/__init__.py
@@ -6,10 +6,12 @@
     AGGREGATED_GENERATION_PROMPT,
     ATOMIC_GENERATION_PROMPT,
     COT_GENERATION_PROMPT,
+    FILL_IN_BLANK_GENERATION_PROMPT,
+    MAQ_GENERATION_PROMPT,
+    MCQ_GENERATION_PROMPT,
     MULTI_HOP_GENERATION_PROMPT,
     VQA_GENERATION_PROMPT,
 )
 from .kg import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT, MMKG_EXTRACTION_PROMPT
-from .question_generation import QUESTION_GENERATION_PROMPT
 from .search_judgement import SEARCH_JUDGEMENT_PROMPT
 from .statement_judgement import STATEMENT_JUDGEMENT_PROMPT
diff --git a/graphgen/templates/generation/__init__.py b/graphgen/templates/generation/__init__.py
index b58c2b6c..3554d42f 100644
--- a/graphgen/templates/generation/__init__.py
+++ b/graphgen/templates/generation/__init__.py
@@ -1,5 +1,8 @@
 from .aggregated_generation import AGGREGATED_GENERATION_PROMPT
 from .atomic_generation import ATOMIC_GENERATION_PROMPT
 from .cot_generation import COT_GENERATION_PROMPT
+from .fill_in_blank_generation import FILL_IN_BLANK_GENERATION_PROMPT
+from .multi_answer_generation import MAQ_GENERATION_PROMPT
+from .multi_choice_generation import MCQ_GENERATION_PROMPT
 from .multi_hop_generation import MULTI_HOP_GENERATION_PROMPT
 from .vqa_generation import VQA_GENERATION_PROMPT
diff --git a/graphgen/templates/generation/classification_generation.py b/graphgen/templates/generation/classification_generation.py
new file mode 100644
index 00000000..e69de29b
diff --git a/graphgen/templates/generation/fill_in_blank_generation.py b/graphgen/templates/generation/fill_in_blank_generation.py
new file mode 100644
index 00000000..edc1c323
--- /dev/null
+++ b/graphgen/templates/generation/fill_in_blank_generation.py
@@ -0,0 +1,78 @@
+TEMPLATE_ZH = """请根据上下文资料生成独立的知识问答填空题。填空题的答案必须能在原文中直接找到。
+
+生成要求:
+1. **语言一致性**:若上下文资料为中文,则生成中文问题;若为英文,则生成英文问题
+2. **数量**:每个上下文资料生成{num_of_questions}个填空题
+3. **独立性**:每个问题必须完整独立,不依赖其他问题
+4. **准确性**:正确答案必须能从原文直接得出
+5. **占位符格式**:使用________(四个下划线)作为填空占位符
+
+输出格式:
+<qa_pairs>
+<qa_pair>
+<question>问题文本(使用________作为占位符)</question>
+<answer>正确答案文本(多个空用逗号分隔)</answer>
+</qa_pair>
+</qa_pairs>
+
+示例(根据iPad Air 2生成2题):
+<qa_pairs>
+<qa_pair>
+<question>iPad Air 2 是由________制造的?</question>
+<answer>美国苹果公司(Apple)</answer>
+</qa_pair>
+<qa_pair>
+<question>iPad Air 2 的发布日期是________,上市日期是________。</question>
+<answer>2014年10月16日,2014年10月22日</answer>
+</qa_pair>
+</qa_pairs>
+
+上下文资料:
+{context}
+
+请为以下资料生成{num_of_questions}个填空题:
+"""
+
+
+TEMPLATE_EN = """Generate independent fill-in-the-blank questions based on the provided context. \
+Answers must be directly derivable from the text.
+
+Requirements:
+1. **Language Consistency**: Generate in the same language as the context (Chinese/English)
+2. **Quantity**: Generate {num_of_questions} questions per context
+3. **Independence**: Each question must be self-contained
+4. **Accuracy**: Correct answer must be directly found in the source text
+5. **Placeholder Format**: Use ________ (four underscores) as the blank placeholder
+
+Output Format:
+<qa_pairs>
+<qa_pair>
+<question>Question text (use ________ as placeholder)</question>
+<answer>Correct answer text (separate multiple blanks with commas)</answer>
+</qa_pair>
+</qa_pairs>
+
+Example (2 questions):
+<qa_pairs>
+<qa_pair>
+<question>The iPad Air 2 was manufactured by ________.</question>
+<answer>Apple Inc.</answer>
+</qa_pair>
+<qa_pair>
+<question>The iPad Air 2 was released on ________ and launched on ________.</question>
+<answer>October 16, 2014, October 22, 2014</answer>
+</qa_pair>
+</qa_pairs>
+
+Context:
+{context}
+
+Please generate {num_of_questions} fill-in-the-blank questions for the following context:
+"""
+
+
+FILL_IN_BLANK_GENERATION_PROMPT = {
+    "zh": TEMPLATE_ZH,
+    "en": TEMPLATE_EN,
+}
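A rendering sketch for the template above; only `{context}` and `{num_of_questions}` are format fields, while the XML tags are literal text (the double-braced `{{context}}` in the original draft would have survived `str.format` as a literal `{context}`, so it is corrected to a single-braced field here, matching the other templates):

```python
# Sketch: rendering the English fill-in-blank prompt with sample context.
from graphgen.templates import FILL_IN_BLANK_GENERATION_PROMPT

prompt = FILL_IN_BLANK_GENERATION_PROMPT["en"].format(
    context="1. iPad Air 2: a tablet released by Apple on October 16, 2014.",
    num_of_questions=2,
)
print(prompt)
```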
diff --git a/graphgen/templates/generation/multi_answer_generation.py b/graphgen/templates/generation/multi_answer_generation.py
new file mode 100644
index 00000000..04408fe9
--- /dev/null
+++ b/graphgen/templates/generation/multi_answer_generation.py
@@ -0,0 +1,100 @@
+TEMPLATE_ZH = """请根据上下文资料生成独立的知识问答不定项选择题,每个选择题包含四个选项,其中有若干个正确答案(至少一个),其他为干扰项。
+
+生成要求:
+1. **语言一致性**:若上下文资料为中文,则生成中文问题;若为英文,则生成英文问题
+2. **数量**:每个上下文资料生成{num_of_questions}个选择题
+3. **独立性**:每个问题必须完整独立,不依赖其他问题
+4. **准确性**:正确答案必须能从原文直接得出,干扰项需合理且有区分度
+5. **答案格式**:当有多个正确答案时,用逗号分隔选项字母,如"A, B, C"
+
+输出格式:
+<qa_pairs>
+<qa_pair>
+<question>问题文本</question>
+<options>A. 选项A文本
+B. 选项B文本
+C. 选项C文本
+D. 选项D文本</options>
+<answer>正确答案选项字母(多个答案用逗号分隔)</answer>
+</qa_pair>
+</qa_pairs>
+
+示例(根据iPad Air 2生成2题):
+<qa_pairs>
+<qa_pair>
+<question>iPad Air 2的发布年份是?</question>
+<options>A. 2012年
+B. 2014年
+C. 2015年
+D. 2017年</options>
+<answer>B</answer>
+</qa_pair>
+<qa_pair>
+<question>以下哪些是 iPad Air 2 的特点?</question>
+<options>A. Touch ID指纹识别功能
+B. A8X高效处理器
+C. 一千万像素前置相机
+D. 八百万像素后置相机镜头</options>
+<answer>A, B, D</answer>
+</qa_pair>
+</qa_pairs>
+
+上下文资料:
+{context}
+
+请为以下资料生成{num_of_questions}个不定项选择题:
+"""
+
+
+TEMPLATE_EN = """Generate independent multiple-select knowledge questions \
+based on the provided context. Each question should contain four options \
+with one or more correct answers and distractors.
+
+Requirements:
+1. **Language Consistency**: Generate in the same language as the context (Chinese/English)
+2. **Quantity**: Generate {num_of_questions} questions per context
+3. **Independence**: Each question must be self-contained
+4. **Accuracy**: Correct answer(s) must be derivable from the text, and distractors should be plausible
+5. **Answer Format**: For multiple correct answers, separate option letters with commas, e.g., "A, B, C"
+
+Output Format:
+<qa_pairs>
+<qa_pair>
+<question>Question text</question>
+<options>A. Option A text
+B. Option B text
+C. Option C text
+D. Option D text</options>
+<answer>Correct option letter(s) (separate multiple answers with commas)</answer>
+</qa_pair>
+</qa_pairs>
+
+Example (2 questions):
+<qa_pairs>
+<qa_pair>
+<question>What are the features of iPad Air 2?</question>
+<options>A. Touch ID fingerprint recognition
+B. A8X processor
+C. Ten-megapixel front camera
+D. Eight-megapixel rear camera</options>
+<answer>A, B, D</answer>
+</qa_pair>
+<qa_pair>
+<question>When was iPad Air 2 discontinued?</question>
+<options>A. March 21, 2016
+B. March 21, 2017
+C. October 22, 2017
+D. October 16, 2016</options>
+<answer>B</answer>
+</qa_pair>
+</qa_pairs>
+
+Context:
+{context}
+
+Please generate {num_of_questions} multiple-select questions for the following context:
+"""
+
+
+MAQ_GENERATION_PROMPT = {"zh": TEMPLATE_ZH, "en": TEMPLATE_EN}
diff --git a/graphgen/templates/generation/multi_choice_generation.py b/graphgen/templates/generation/multi_choice_generation.py
new file mode 100644
index 00000000..b69fa6ab
--- /dev/null
+++ b/graphgen/templates/generation/multi_choice_generation.py
@@ -0,0 +1,97 @@
+TEMPLATE_GENERATION_ZH: str = """请根据上下文资料生成独立的知识问答单选题,每个选择题包含四个选项,其中仅有一个正确答案,其他三个为干扰项。
+
+生成要求:
+1. **语言一致性**:若上下文资料为中文,则生成中文问题;若为英文,则生成英文问题
+2. **数量**:每个上下文资料生成{num_of_questions}个选择题
+3. **独立性**:每个问题必须完整独立,不依赖其他问题
+4. **准确性**:正确答案必须能从原文直接得出,干扰项需合理且有区分度
+
+输出格式:
+<qa_pairs>
+<qa_pair>
+<question>问题文本</question>
+<options>A. 选项A文本
+B. 选项B文本
+C. 选项C文本
+D. 选项D文本</options>
+<answer>正确答案选项字母</answer>
+</qa_pair>
+</qa_pairs>
+
+示例(根据iPad Air 2生成2题):
+<qa_pairs>
+<qa_pair>
+<question>iPad Air 2的发布年份是?</question>
+<options>A. 2012年
+B. 2014年
+C. 2015年
+D. 2017年</options>
+<answer>B</answer>
+</qa_pair>
+<qa_pair>
+<question>iPad Air 2搭载的处理器型号是?</question>
+<options>A. A8
+B. A9X
+C. A8X
+D. A10</options>
+<answer>C</answer>
+</qa_pair>
+</qa_pairs>
+
+上下文资料:
+{context}
+
+请为以下资料生成{num_of_questions}个选择题:
+"""
+
+TEMPLATE_GENERATION_EN: str = """Generate independent multiple-choice questions \
+based on the provided context. Each question should contain four options \
+with only one correct answer and three distractors.
+
+Requirements:
+1. **Language Consistency**: Generate in the same language as the context (Chinese/English)
+2. **Quantity**: Generate {num_of_questions} questions per context
+3. **Independence**: Each question must be self-contained
+4. **Accuracy**: Correct answer must be derivable from the text, and distractors should be plausible
+
+Output Format:
+<qa_pairs>
+<qa_pair>
+<question>Question text</question>
+<options>A. Option A text
+B. Option B text
+C. Option C text
+D. Option D text</options>
+<answer>Correct option letter</answer>
+</qa_pair>
+</qa_pairs>
+
+Example (2 questions):
+<qa_pairs>
+<qa_pair>
+<question>What year was the iPad Air 2 released?</question>
+<options>A. 2012
+B. 2014
+C. 2015
+D. 2017</options>
+<answer>B</answer>
+</qa_pair>
+<qa_pair>
+<question>Which processor does iPad Air 2 use?</question>
+<options>A. A8
+B. A9X
+C. A8X
+D. A10</options>
+<answer>C</answer>
+</qa_pair>
+</qa_pairs>
+
+Context:
+{context}
+
+Please generate {num_of_questions} questions for the following context:
+"""
+
+
+MCQ_GENERATION_PROMPT = {"zh": TEMPLATE_GENERATION_ZH, "en": TEMPLATE_GENERATION_EN}
diff --git a/graphgen/templates/question_generation.py b/graphgen/templates/question_generation.py
deleted file mode 100644
index e75bf169..00000000
--- a/graphgen/templates/question_generation.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# pylint: disable=C0301
-
-
-# TODO: 修改这里的prompt
-TEMPLATE_MULTI_EN = """You are an assistant to help read a article and then rephrase it in a question answering format. The user will provide you with an article with its content. You need to generate a paraphrase of the same article in question and answer format with one tag of "Question: ..." followed by "Answer: ...". Remember to keep the meaning and every content of the article intact.
-
-Here is the format you should follow for your response:
-Question: <question>
-Answer: <answer>
-
-Here is the article you need to rephrase:
-{doc}
-"""
-
-TEMPLATE_MULTI_ZH = """你是一位助手,帮助阅读一篇文章,然后以问答格式重述它。用户将为您提供一篇带有内容的文章。你需要以一个标签"问题:..."为开头,接着是"答案:...",生成一篇与原文章相同的问答格式的重述。请确保保持文章的意义和每个内容不变。
-
-以下是你应该遵循的响应格式:
-问题: <问题>
-答案: <答案>
-
-以下是你需要重述的文章:
-{doc}
-"""
-
-QUESTION_GENERATION_PROMPT = {
-    "English": {
-        "MULTI_TEMPLATE": TEMPLATE_MULTI_EN,
-    },
-    "Chinese": {
-        "MULTI_TEMPLATE": TEMPLATE_MULTI_ZH,
-    },
-}