From 57842f06151059f3ecd109e8e50688456ac571ba Mon Sep 17 00:00:00 2001 From: Bill Ray Date: Tue, 17 Oct 2023 17:41:09 -0400 Subject: [PATCH 1/3] Initial comit of language model lib --- align_system/algorithms/llm_chat_baseline.py | 17 +- align_system/language_model_lib/__init__.py | 15 ++ .../language_model_lib/chat_langauge_model.py | 115 +++++++++++ align_system/language_model_lib/dialog.py | 70 +++++++ .../language_model_lib/dialog_tokenizer.py | 58 ++++++ .../language_model_lib/language_model.py | 160 +++++++++++++++ .../llama_2_kdma_predicting_adm.py | 187 ++++++++++++++++++ .../test_chat_language_model.py | 39 ++++ .../language_model_lib/test_language_model.py | 54 +++++ align_system/language_model_lib/util.py | 81 ++++++++ 10 files changed, 790 insertions(+), 6 deletions(-) create mode 100644 align_system/language_model_lib/__init__.py create mode 100644 align_system/language_model_lib/chat_langauge_model.py create mode 100644 align_system/language_model_lib/dialog.py create mode 100644 align_system/language_model_lib/dialog_tokenizer.py create mode 100644 align_system/language_model_lib/language_model.py create mode 100644 align_system/language_model_lib/llama_2_kdma_predicting_adm.py create mode 100644 align_system/language_model_lib/test_chat_language_model.py create mode 100644 align_system/language_model_lib/test_language_model.py create mode 100644 align_system/language_model_lib/util.py diff --git a/align_system/algorithms/llm_chat_baseline.py b/align_system/algorithms/llm_chat_baseline.py index 8b5b7c27..3c739421 100644 --- a/align_system/algorithms/llm_chat_baseline.py +++ b/align_system/algorithms/llm_chat_baseline.py @@ -88,12 +88,17 @@ def __init__(self, device='cuda', hf_model='meta-llama/Llama-2-7b-chat-hf', prec self.tokenizer = None - def load_model(self): - print('Loading model:', self.hf_model) - self.model = AutoModelForCausalLM.from_pretrained(self.hf_model, torch_dtype=self.precision) - self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model) - - self.model = self.model.to(self.device) + def load_model(self, model=None, tokenizer=None): + assert (model is None) == (tokenizer is None), "model and tokenizer must both be None or both be not None." + if model is not None: + print('Loading model and tokenizer from provided objects.') + self.model = model + self.tokenizer = tokenizer + else: + print('Loading model:', self.hf_model) + self.model = AutoModelForCausalLM.from_pretrained(self.hf_model, torch_dtype=self.precision) + self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model) + self.model = self.model.to(self.device) def get_character_ids(self, character_str): diff --git a/align_system/language_model_lib/__init__.py b/align_system/language_model_lib/__init__.py new file mode 100644 index 00000000..0f451679 --- /dev/null +++ b/align_system/language_model_lib/__init__.py @@ -0,0 +1,15 @@ +from importlib import reload + +def reload_all(): + # Import the modules inside this function to ensure they're available for reloading + + from . import util + from . import language_model as lm + from . import dialog_tokenizer as dt + from . import chat_langauge_model as clm + from . import llama_2_kdma_predicting_adm as kpa + + + # Reload in the correct order + for module in [util, lm, dt, clm, kpa]: + reload(module) diff --git a/align_system/language_model_lib/chat_langauge_model.py b/align_system/language_model_lib/chat_langauge_model.py new file mode 100644 index 00000000..ec508ebe --- /dev/null +++ b/align_system/language_model_lib/chat_langauge_model.py @@ -0,0 +1,115 @@ +from align_system.language_model_lib.language_model import LanguageModel +from align_system.language_model_lib.dialog_tokenizer import dialog_tokenizers +from align_system.language_model_lib.util import read_file, format_template, dialog_from_string, dialog_to_string + + +class ChatLanguageModel(LanguageModel): + + def __init__(self, model, tokenizer): + super().__init__(model, tokenizer) + model_name = model.name_or_path + assert model_name in dialog_tokenizers, f'No dialog tokenizer found for model {model_name}' + self.dialog_tokenizer = dialog_tokenizers[model_name](tokenizer) + + def generate_responses(self, dialogs, log_file=None, max_new_tokens=512, temperature=0.6): + if log_file is not None: + log_file.write('**Dialogs:**\n') + for i, dialog in enumerate(dialogs): + log_file.write(f'*Dialog {i}:*\n{dialog_to_string(dialog)}\n') + log_file.flush() + # Remove the last dialog piece if it is an assistant response + # Use the assistant response as a prefix + user_last_dialogs = [] + prefixes = [] + for dialog in dialogs: + prefix = '' + if dialog[-1]['role'] == 'assistant': + prefix = dialog[-1]['content'] + dialog = dialog[:-1] + user_last_dialogs.append(dialog) + prefixes.append(prefix) + dialogs = user_last_dialogs + + prompt_token_lists = [ + [self.dialog_tokenizer.dialog_to_tokens(dialog)] + for dialog in dialogs + ] + + for prompt_tokens, prefix in zip(prompt_token_lists, prefixes): + if len(prefix) > 0: + prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False) + prompt_tokens[0] += prefix_tokens + + prompt_token_lists = [x[0] for x in prompt_token_lists] + responses = self.generate_from_tokens(prompt_token_lists, max_new_tokens=max_new_tokens, temperature=temperature) + + prefixed_responses = [ + f'{prefix}{response}' + for prefix, response in zip(prefixes, responses) + ] + + if log_file is not None: + log_file.write('**Generated Responses:**\n') + for i, response in enumerate(prefixed_responses): + log_file.write(f'*Response {i}:*\n{response}\n') + log_file.flush() + + return prefixed_responses + + + def generate_from_template( + self, + template_files, + substitution_dicts, + parse_generation_fn=None, + batch_size=5, + log_file=None, + max_tokens=512, + temperature=0.6, + max_retry=10, + verbose=False + ): + if type(substitution_dicts) is dict: + substitution_dicts = [substitution_dicts] + + if type(template_files) is str: + template_files = [template_files] * len(substitution_dicts) + + assert len(template_files) == len(substitution_dicts), 'Number of templates and substitutions do not match' + + dialogs = { + i: dialog_from_string(format_template(read_file(template_file), **substitutions)) + for i, (template_file, substitutions) in enumerate(zip(template_files, substitution_dicts)) + } + + outputs = {} + input_counts = {} + while len(dialogs) > 0: + sample_ids = list(dialogs.keys())[:batch_size] + batch = [dialogs[i] for i in sample_ids] + generations = self.generate_responses(batch, log_file=log_file, max_new_tokens=max_tokens, temperature=temperature) + + for sample_id, generation in zip(sample_ids, generations): + input_counts[sample_id] = input_counts.get(sample_id, 0) + 1 + if input_counts[sample_id] > max_retry: + raise Exception(f'Could not generate valid output for sample [{sample_id}]') + + if parse_generation_fn is not None: + try: + outputs[sample_id] = parse_generation_fn(generation) + del dialogs[sample_id] + except Exception as e: + if verbose: + print(f'Error: could not parse output for sample [{sample_id}]') + print(e) + pass + else: + outputs[sample_id] = generation + del dialogs[sample_id] + + assert len(outputs) == len(substitution_dicts), 'Unexpected state: number of outputs and substitutions do not match' + + return [ + outputs[i] + for i in range(len(outputs)) + ] \ No newline at end of file diff --git a/align_system/language_model_lib/dialog.py b/align_system/language_model_lib/dialog.py new file mode 100644 index 00000000..5934f122 --- /dev/null +++ b/align_system/language_model_lib/dialog.py @@ -0,0 +1,70 @@ +import re + + +class Dialog: + + @classmethod + def from_string(cls, string): + dialog_markers = { + '=== system': 'system', + '=== user': 'user', + '=== assistant': 'assistant', + } + dialog = [] + lines = string.split('\n') + current_role = '' + current_content = '' + for line in lines: + if line.strip() in dialog_markers: + if current_role and current_content: + dialog.append({ + 'role': current_role, + 'content': current_content.strip() + }) + current_role = dialog_markers[line.strip()] + current_content = '' + else: + current_content += f'{line}\n' + if current_role and current_content: + dialog.append({ + 'role': current_role, + 'content': current_content.strip() + }) + return dialog + + + @classmethod + def from_template(cls, template, **substitutions): + cls.from_string(format_template(template, **substitutions)) + + + def __init__(self, messages): + self.messages = messages + + + def __str__(self): + output = '' + + for dialog_piece in self.messages: + role = dialog_piece['role'] + content = dialog_piece['content'] + output += f"=== {role}\n" + output += f"{content}\n" + + return output + + +def format_template(template, **substitutions): + for key, value in substitutions.items(): + key = '{{%s}}' % key + if not key in template: + raise Exception(f'Could not find key {key} in template') + template = template.replace(key, value) + + # ensure there are no strings sorrounded by {{ }} + matches = re.findall(r'{{.*?}}', template) + # if there are any matches, raise an exception + if len(matches) > 0: + raise Exception(f'Could not find values for {matches} in template') + + return template \ No newline at end of file diff --git a/align_system/language_model_lib/dialog_tokenizer.py b/align_system/language_model_lib/dialog_tokenizer.py new file mode 100644 index 00000000..b558a974 --- /dev/null +++ b/align_system/language_model_lib/dialog_tokenizer.py @@ -0,0 +1,58 @@ +from abc import abstractmethod + +class DialogTokenizer: + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + @abstractmethod + def dialog_to_tokens(self, dialog_messages): + pass + + +class Llama2DialogTokenizer(DialogTokenizer): + + + def dialog_to_tokens(self, dialog_messages): + # Define instance and system borders + B_INST, E_INST = "[INST]", "[/INST]" + B_SYS, E_SYS = "<>\n", "\n<>\n\n" + + # If the role of the first message is system + if dialog_messages[0]["role"] == "system": + # Create an initial dialog entry combining system and user messages + system_dialog = {"role": dialog_messages[1]["role"], + "content": B_SYS + dialog_messages[0]["content"] + E_SYS + dialog_messages[1]["content"]} + # Update dialog to start with system_dialog and followed by the rest of the dialog + dialog_messages = [system_dialog] + dialog_messages[2:] + + # Ensure the correct dialog order (system, user, assistant, user, assistant... ) + assert all([msg["role"] == "user" for msg in dialog_messages[::2]]) and all( + [msg["role"] == "assistant" for msg in dialog_messages[1::2]]), \ + "Model only supports 'system', 'user' and 'assistant' roles, in the sequence (s/u/a/u/a...)" + + # Encode each user message and its following assistant message into tokens + dialog_tokens = [] + for prompt, answer in zip(dialog_messages[::2], dialog_messages[1::2]): + tokenized_message = ([self.tokenizer.bos_token_id] + + self.tokenizer.encode(f"{B_INST} {prompt['content'].strip()} {E_INST} {answer['content'].strip()} ", + add_special_tokens=False) + + [self.tokenizer.eos_token_id]) + dialog_tokens.extend(tokenized_message) + + # Ensure the final message is from the user + assert dialog_messages[-1]["role"] == "user", "Last message must be from the user." + + # Encode the user's final message into tokens and add to dialog_tokens + user_final_message_tokens = ([self.tokenizer.bos_token_id] + self.tokenizer.encode( + f"{B_INST} {dialog_messages[-1]['content'].strip()} {E_INST}", + add_special_tokens=False)) + dialog_tokens.extend(user_final_message_tokens) + + return dialog_tokens + + +dialog_tokenizers = { + 'meta-llama/Llama-2-7b-chat-hf': Llama2DialogTokenizer, + 'meta-llama/Llama-2-13b-chat-hf': Llama2DialogTokenizer, +} \ No newline at end of file diff --git a/align_system/language_model_lib/language_model.py b/align_system/language_model_lib/language_model.py new file mode 100644 index 00000000..b5f47c98 --- /dev/null +++ b/align_system/language_model_lib/language_model.py @@ -0,0 +1,160 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from typing import List + +class LanguageModel: + """ + A class that handles transformers Language Models + """ + + @classmethod + def load_model(cls, hf_model_name: str, precision: torch.dtype = torch.float32, device: str = 'cuda') -> 'LanguageModel': + """ + Loads the specified transformer model and tokenizer. + + Args: + hf_model_name (str): The huggingface model name. + precision (torch.dtype, optional): The precision of the model weights. Defaults to torch.float32. + device (str, optional): The device to move the model to. Defaults to 'cuda'. + + Returns: + LanguageModel: An instance of this class with the loaded model and tokenizer. + """ + model = AutoModelForCausalLM.from_pretrained(hf_model_name, torch_dtype=precision) + tokenizer = AutoTokenizer.from_pretrained(hf_model_name) + model = model.to(device) + return cls(model, tokenizer) + + + def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer) -> None: + """ + Initializes the LanguageModel instance with the given model and tokenizer. + + Args: + model (AutoModelForCausalLM): The loaded transformer model. + tokenizer (AutoTokenizer): The loaded tokenizer. + """ + self.model = model + self.tokenizer = tokenizer + + + def generate_from_tokens(self, prompt_token_lists: List[List[int]], log_file=None, max_new_tokens: int=512, temperature: float=0.6, padding='left'): + """ + Generates text from a list of tokenized prompts. + + Args: + prompt_token_lists (List[List[int]]): A batch of lists where each list is a sequence of tokens. + max_new_tokens (int, optional): The maximum number of tokens to generate. Defaults to 512. + temperature (float, optional): The temperature for the generation algorithm. Defaults to 0.6. + + Returns: + List[str]: The generated text for each prompt in the input list. Only contains text after the prompt. + """ + prompt_token_lists = [ + torch.tensor(prompt_tokens).to(self.model.device).unsqueeze(0) + for prompt_tokens in prompt_token_lists + ] + + max_length = max([prompt_tokens.size(1) for prompt_tokens in prompt_token_lists]) + + pad_token_id = self.tokenizer.pad_token_id + # Pad each sequence to the max length + assert padding == 'left' or padding == 'right', f"Padding must be either 'left' or 'right', got {padding}" + pad_fn = lambda prompt_token_size: (max_length - prompt_token_size, 0) if padding == 'left' else (0, max_length - prompt_token_size) + + padded_prompt_token_lists = [ + torch.nn.functional.pad(prompt_tokens, pad_fn(prompt_tokens.size(1)), value=pad_token_id) + for prompt_tokens in prompt_token_lists + ] + + attention_masks = [ + torch.nn.functional.pad(torch.ones_like(prompt_tokens), pad_fn(prompt_tokens.size(1)), value=0) + for prompt_tokens in prompt_token_lists + ] + + position_ids = [ + torch.nn.functional.pad(torch.arange(prompt_tokens.size(1)).unsqueeze(0), pad_fn(prompt_tokens.size(1)), value=0) + for prompt_tokens in prompt_token_lists + ] + + + # Stack the padded sequences + stacked_prompt_tokens = torch.cat(padded_prompt_token_lists, dim=0) + stacked_attention_masks = torch.cat(attention_masks, dim=0) + stacked_position_ids = torch.cat(position_ids, dim=0) + + if log_file is not None: + prompt_texts = [ + self.tokenizer.decode(prompt_tokens.squeeze(0), skip_special_tokens=True) + for prompt_tokens in padded_prompt_token_lists + ] + log_file.write('**Prompt texts:**\n') + for i, prompt_text in enumerate(prompt_texts): + log_file.write(f'Prompt {i}:\n{prompt_text}\n') + + log_file.flush() + + + + # Generate outputs for all dialogs in a batch + # TODO ensure the batch size is not too large for the GPU + outputs = self.model.generate( + stacked_prompt_tokens, + attention_mask=stacked_attention_masks, + # position_ids=stacked_position_ids, # TODO figure out why including the position ids breaks the model + return_dict_in_generate=True, + output_scores=True, + max_new_tokens=max_new_tokens, + temperature=temperature + ) + + # Decode the generated outputs + decoded_outputs = [ + self.tokenizer.decode(output_tokens[len(prompt_tokens.squeeze(0)):], skip_special_tokens=True) + for output_tokens, prompt_tokens in zip(outputs.sequences, padded_prompt_token_lists) + ] + + if log_file is not None: + log_file.write('**Generated texts:**\n') + for i, decoded_output in enumerate(decoded_outputs): + log_file.write(f'*Generation {i}:*\n{decoded_output}\n') + log_file.flush() + + return decoded_outputs + + + + def generate(self, prompt_texts: List[str], log_file=None, max_new_tokens: int=512, temperature: float=0.6): + """ + Generates text from a list of prompts. + + Args: + prompt_texts (List[str]): A list of prompts. + max_new_tokens (int, optional): The maximum number of tokens to generate. Defaults to 512. + temperature (float, optional): The temperature for the generation algorithm. Defaults to 0.6. + + Returns: + List[str]: The generated text for each prompt in the input list. Only contains text after the prompt. + """ + # Convert text prompts to token prompts + prompt_token_lists = [self.tokenizer.encode(prompt_text) for prompt_text in prompt_texts] + return self.generate_from_tokens(prompt_token_lists, log_file, max_new_tokens, temperature) + + + def generate_with_prefixes(self, prompt_texts: List[str], prefixes: List[str], log_file=None, max_new_tokens: int=512, temperature: float=0.6): + """ + Generates text from a list of prompts with a list of prefixes. + + Args: + prompt_texts (List[str]): A list of prompts. + prefixes (List[str]): A list of prefixes. + max_new_tokens (int, optional): The maximum number of tokens to generate. Defaults to 512. + temperature (float, optional): The temperature for the generation algorithm. Defaults to 0.6. + + Returns: + List[str]: The generated text for each prompt in the input list. Includes the prefix but not the prompt. + """ + combined_texts = [f'{prompt}{prefix}' for prompt, prefix in zip(prompt_texts, prefixes)] + generations = self.generate(combined_texts, log_file, max_new_tokens, temperature) + return [f'{prefix}{generation}' for prefix, generation in zip(prefixes, generations)] \ No newline at end of file diff --git a/align_system/language_model_lib/llama_2_kdma_predicting_adm.py b/align_system/language_model_lib/llama_2_kdma_predicting_adm.py new file mode 100644 index 00000000..bfa4e940 --- /dev/null +++ b/align_system/language_model_lib/llama_2_kdma_predicting_adm.py @@ -0,0 +1,187 @@ +import json + +from align_system.language_model_lib.chat_langauge_model import ChatLanguageModel +from align_system.language_model_lib.util import extract_kdma_description + +class Llama2KDMAPredictingADM(ChatLanguageModel): + + def predict_outcomes( + self, + scenario, + probe, + choices, + log_file=None, + max_tokens=512, + temperature=0.6, + outcome_template_file='templates/predict_outcomes.md' + ): + return self.generate_from_template( + outcome_template_file, + [ + { + 'scenario': scenario, + 'probe': probe, + 'choice': choice, + } + for choice in choices + ], + log_file=log_file, + max_tokens=max_tokens, + temperature=temperature + ) + + + def predict_kdma_scores( + self, + scenario_text, + probe_text, + choice_texts, + predicted_outcomes=None, + generate_reasoning=True, + log_file=None, + max_new_tokens=512, + temperature=0.6, + kdma_template_file='templates/kdma.md', + kdma_descriptions_file='templates/bbn_kdma_descriptions.md', + ): + choice_ids = [f'choice_{i}' for i in range(len(choice_texts))] + substitutions = [] + info = [] + kdma_descriptions = extract_kdma_description(kdma_descriptions_file) + if predicted_outcomes is None: + predicted_outcomes = [None] * len(choice_texts) + for choice_id, choice, outcome in zip(choice_ids, choice_texts, predicted_outcomes): + for kdma, kdma_description in kdma_descriptions.items(): + substitution = { + 'kdma': kdma, + 'kdma_description': kdma_description, + 'scenario': scenario_text, + 'probe': probe_text, + 'choice': choice, + } + + if outcome is not None: + substitution['outcome'] = outcome + + substitutions.append(substitution) + info.append((choice_id, kdma)) + + def parse_kdma_score_response(response): + if generate_reasoning: + start_idx = response.find('{') + end_idx = response.rfind('}') + response_json = json.loads(response[start_idx:end_idx+1]) + assert 'score' in response_json, 'score not found in response' + assert 'reasoning' in response_json, 'reasoning not found in response' + else: + # find the first numeric character + char = None + for c in response: + if c.isnumeric(): + char = c + break + assert char is not None, 'Could not find numeric character in response' + response_json = { + 'score': float(response[response.find(char):]) + } + + return response_json + + generations = self.generate_from_template( + kdma_template_file, + substitutions, + parse_kdma_score_response, + log_file=log_file, + max_tokens=max_new_tokens, + temperature=temperature, + ) + + predicted_kdmas = {} + reasonings = {} + for (choice_id, kdma), generation in zip(info, generations): + predicted_choice_kdmas = predicted_kdmas.get(choice_id, {}) + predicted_kdmas[choice_id] = predicted_choice_kdmas + + choice_reasonings = reasonings.get(choice_id, {}) + reasonings[choice_id] = choice_reasonings + + predicted_choice_kdmas[kdma] = generation['score'] + + if generate_reasoning: + choice_reasonings[kdma] = generation['reasoning'] + + + predicted_kdmas = [ + predicted_kdmas[choice_id] + for choice_id in choice_ids + ] + if generate_reasoning: + reasonings = [ + reasonings[choice_id] + for choice_id in choice_ids + ] + + if generate_reasoning: + return predicted_kdmas, reasonings + else: + return predicted_kdmas + + def make_aligned_descision( + self, + scenario, + probe, + choices, + target_kdmas, + alignment_fn, + predict_outcomes=True, + generate_reasoning=True, + kdma_descriptions_file='templates/bbn_kdma_descriptions.md', + outcome_template_file='templates/predict_outcomes.md', + kdma_template_file='templates/predict_kdma_scores_reasoning.md', + ): + # Generate the outcomes for each choice + outcomes = None + if predict_outcomes: + outcomes = self.predict_outcomes( + scenario, + probe, + choices, + outcome_template_file=outcome_template_file + ) + + assert len(choices) == len(outcomes), 'Unexpected state: number of choices and outcomes do not match' + + # Get the scores and reasonings for each choice + predicted_kdma_scores = self.predict_kdma_scores( + scenario, + probe, + choices, + outcomes=outcomes, + kdma_template_file=kdma_template_file, + kdma_descriptions_file=kdma_descriptions_file + ) + + if generate_reasoning: + scores, reasonings = predicted_kdma_scores + else: + scores = predicted_kdma_scores + + assert len(choices) == len(scores), 'Unexpected state: number of choices and scores do not match' + + # Compute the similarity score for each choice + alignment_scores = [] + for score in scores: + alignment_scores.append(alignment_fn(target_kdmas, score)) + + max_idx = alignment_scores.index(max(alignment_scores)) + + justification = { + 'choice': choices[max_idx], + 'outcome': outcomes[max_idx], + 'kdma_scores': scores[max_idx], + } + + if generate_reasoning: + justification['kdma_reasonings'] = reasonings[max_idx] + + return max_idx, justification \ No newline at end of file diff --git a/align_system/language_model_lib/test_chat_language_model.py b/align_system/language_model_lib/test_chat_language_model.py new file mode 100644 index 00000000..c73d9292 --- /dev/null +++ b/align_system/language_model_lib/test_chat_language_model.py @@ -0,0 +1,39 @@ +import pytest + +from chat_langauge_model import ChatLanguageModel + +MODEL_TO_TEST = 'meta-llama/Llama-2-7b-chat-hf' + +@pytest.fixture(scope="module") +def chat_language_model(): + # Load the model once for all tests that use this fixture + return ChatLanguageModel.load_model(MODEL_TO_TEST) + + +def test_generate_responses(chat_language_model): + dialogs = [ + [ + {'role': 'system', 'content': 'speak like a pirate'}, + {'role': 'user', 'content': 'hello'}, + ], + [ + {'role': 'system', 'content': 'speak like a pirate'}, + {'role': 'user', 'content': 'hello'}, + {'role': 'assistant', 'content': 'What if you'}, + ], + [ + {'role': 'system', 'content': 'speak like a pirate'}, + {'role': 'user', 'content': 'hello'}, + {'role': 'assistant', 'content': 'What if you'}, + ] + ] + + responses = chat_language_model.generate_responses(dialogs, max_new_tokens=512, temperature=0.0001) + + assert type(responses) is list + assert len(responses) == len(responses) + assert type(responses[0]) is str + assert responses[1].startswith(dialogs[1][-1]['content']) + + + diff --git a/align_system/language_model_lib/test_language_model.py b/align_system/language_model_lib/test_language_model.py new file mode 100644 index 00000000..d22cf99e --- /dev/null +++ b/align_system/language_model_lib/test_language_model.py @@ -0,0 +1,54 @@ +import pytest +import torch + +from language_model import LanguageModel + +MODEL_TO_TEST = 'gpt2' # Use a smaller model for testing + +@pytest.fixture(scope="module") +def language_model(): + # Load the model once for all tests that use this fixture + return LanguageModel.load_model(MODEL_TO_TEST, device='cpu') + +def test_load_model(language_model): + assert language_model.model.dtype == torch.float32 + assert language_model.model.device.type == 'cpu' + + +def test_generate_from_tokens(language_model): + tokens = [ + [9246, 9703, 9246, 9703], + [1681, 146, 1681, 146, 1681], + ] + + generations = language_model.generate_from_tokens(tokens, max_new_tokens=1, temperature=0) + + assert generations == [ + 'cat', + '\n' + ] + +def test_generate(language_model): + prompts = [ + 'catdogcatdog', + 'ABCABCABCABCABC', + ] + generations = language_model.generate(prompts, max_new_tokens=1, temperature=0) + assert generations == [ + 'cat', + 'ABC', + ] + +def test_generate_with_prefixes(language_model): + prompts = [ + 'catdogcatdog', + 'ABCABCABCABCABC', + ] + prefixes = [ + 'cat', + 'ABC', + ] + generations = language_model.generate_with_prefixes(prompts, prefixes=prefixes, max_new_tokens=1, temperature=0) + + for generation, prefix in zip(generations, prefixes): + assert generation.startswith(prefix) \ No newline at end of file diff --git a/align_system/language_model_lib/util.py b/align_system/language_model_lib/util.py new file mode 100644 index 00000000..079ebcf8 --- /dev/null +++ b/align_system/language_model_lib/util.py @@ -0,0 +1,81 @@ +import re + +def dialog_from_string(dialog_string): + dialog_markers = { + '=== system': 'system', + '=== user': 'user', + '=== assistant': 'assistant', + } + dialog = [] + lines = dialog_string.split('\n') + current_role = '' + current_content = '' + for line in lines: + if line.strip() in dialog_markers: + if current_role and current_content: + dialog.append({ + 'role': current_role, + 'content': current_content.strip() + }) + current_role = dialog_markers[line.strip()] + current_content = '' + else: + current_content += f'{line}\n' + if current_role and current_content: + dialog.append({ + 'role': current_role, + 'content': current_content.strip() + }) + return dialog + +def read_file(file_path): + with open(file_path, 'r') as f: + return f.read() + +def format_template(template, **substitutions): + for key, value in substitutions.items(): + key = '{{%s}}' % key + if not key in template: + raise Exception(f'Could not find key {key} in template') + template = template.replace(key, value) + + # ensure there are no strings sorrounded by {{ }} + matches = re.findall(r'{{.*?}}', template) + # if there are any matches, raise an exception + if len(matches) > 0: + raise Exception(f'Unsubstituited key(s) in template: {matches}') + + return template + + +def extract_kdma_description(descriptions_file): + kdma_dict = {} + kdma_name = None + kdma_description = '' + + with open(descriptions_file, 'r') as f: + for line in f: + if line.startswith('#'): + if kdma_name is not None: + kdma_dict[kdma_name] = kdma_description.strip() + kdma_name = line[1:].strip() + kdma_description = '' + else: + kdma_description += line + + if kdma_name is not None and kdma_name not in kdma_dict: + kdma_dict[kdma_name] = kdma_description.strip() + + return kdma_dict + + +def dialog_to_string(dialog): + output = '' + + for dialog_piece in dialog: + role = dialog_piece['role'] + content = dialog_piece['content'] + output += f"=== {role}\n" + output += f"{content}\n" + + return output \ No newline at end of file From 15b4e3e1f6490f141106f2a27c1d92534a3d5010 Mon Sep 17 00:00:00 2001 From: Bill Ray Date: Tue, 17 Oct 2023 18:37:14 -0400 Subject: [PATCH 2/3] docstrings --- ...ngauge_model.py => chat_language_model.py} | 111 +++++++++---- align_system/language_model_lib/dialog.py | 70 -------- .../language_model_lib/dialog_tokenizer.py | 34 +++- .../language_model_lib/language_model.py | 121 +++++++------- .../llama_2_kdma_predicting_adm.py | 154 +++++++----------- .../test_chat_language_model.py | 2 +- .../language_model_lib/test_language_model.py | 2 +- align_system/language_model_lib/util.py | 64 ++++++-- 8 files changed, 281 insertions(+), 277 deletions(-) rename align_system/language_model_lib/{chat_langauge_model.py => chat_language_model.py} (56%) delete mode 100644 align_system/language_model_lib/dialog.py diff --git a/align_system/language_model_lib/chat_langauge_model.py b/align_system/language_model_lib/chat_language_model.py similarity index 56% rename from align_system/language_model_lib/chat_langauge_model.py rename to align_system/language_model_lib/chat_language_model.py index ec508ebe..0d8093aa 100644 --- a/align_system/language_model_lib/chat_langauge_model.py +++ b/align_system/language_model_lib/chat_language_model.py @@ -1,24 +1,46 @@ +from typing import List, Dict, Optional, Callable, Union, TextIO + from align_system.language_model_lib.language_model import LanguageModel from align_system.language_model_lib.dialog_tokenizer import dialog_tokenizers from align_system.language_model_lib.util import read_file, format_template, dialog_from_string, dialog_to_string - class ChatLanguageModel(LanguageModel): - - def __init__(self, model, tokenizer): + + def __init__(self, model: LanguageModel, tokenizer: Callable[[str], List[str]]): + """ + Initializes the chat language model. + + :param model: Pretrained language model. + :param tokenizer: Tokenizer function. + """ super().__init__(model, tokenizer) model_name = model.name_or_path assert model_name in dialog_tokenizers, f'No dialog tokenizer found for model {model_name}' self.dialog_tokenizer = dialog_tokenizers[model_name](tokenizer) - - def generate_responses(self, dialogs, log_file=None, max_new_tokens=512, temperature=0.6): + + def generate_responses(self, + dialogs: List[Dict[str, str]], + log_file: Optional[TextIO] = None, + max_new_tokens: int = 512, + temperature: float = 0.6) -> List[str]: + """ + Generates responses for given dialogs. + + :param dialogs: List of dialogs. + :param log_file: Optional file to log the process. + :param max_new_tokens: Maximum number of new tokens to generate. + :param temperature: Temperature for sampling. + :return: Generated responses. + """ + # If logging is requested, write the dialogues into the log file if log_file is not None: log_file.write('**Dialogs:**\n') for i, dialog in enumerate(dialogs): log_file.write(f'*Dialog {i}:*\n{dialog_to_string(dialog)}\n') log_file.flush() - # Remove the last dialog piece if it is an assistant response - # Use the assistant response as a prefix + + # Prepare lists for the last user dialogues and prefixes. + # Prefix refers to the assistant's response in the last turn of a dialogue. user_last_dialogs = [] prefixes = [] for dialog in dialogs: @@ -28,72 +50,91 @@ def generate_responses(self, dialogs, log_file=None, max_new_tokens=512, tempera dialog = dialog[:-1] user_last_dialogs.append(dialog) prefixes.append(prefix) - dialogs = user_last_dialogs - + + # Tokenization step prompt_token_lists = [ [self.dialog_tokenizer.dialog_to_tokens(dialog)] - for dialog in dialogs + for dialog in user_last_dialogs ] + # Add the prefix tokens to the prompt tokens for prompt_tokens, prefix in zip(prompt_token_lists, prefixes): if len(prefix) > 0: prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False) prompt_tokens[0] += prefix_tokens - + + # Generate responses using tokens prompt_token_lists = [x[0] for x in prompt_token_lists] responses = self.generate_from_tokens(prompt_token_lists, max_new_tokens=max_new_tokens, temperature=temperature) - prefixed_responses = [ f'{prefix}{response}' for prefix, response in zip(prefixes, responses) ] + # If logging is requested, write the generated responses into the log file if log_file is not None: log_file.write('**Generated Responses:**\n') for i, response in enumerate(prefixed_responses): log_file.write(f'*Response {i}:*\n{response}\n') log_file.flush() - + return prefixed_responses - - + def generate_from_template( self, - template_files, - substitution_dicts, - parse_generation_fn=None, - batch_size=5, - log_file=None, - max_tokens=512, - temperature=0.6, - max_retry=10, - verbose=False - ): - if type(substitution_dicts) is dict: + template_files: Union[List[str], str], + substitution_dicts: Union[List[Dict[str, str]], Dict[str, str]], + parse_generation_fn: Optional[Callable[[str], str]] = None, + batch_size: int = 5, + log_file: Optional[TextIO] = None, + max_tokens: int = 512, + temperature: float = 0.6, + max_retry: int = 10, + verbose: bool = False) -> List[str]: + """ + Generates responses for given templates with substitutions. + + :param template_files: Template files to use for generation. + :param substitution_dicts: Substitution dictionaries for the templates. + :param parse_generation_fn: Function to parse the generated responses. + :param batch_size: Batch size for generating responses. + :param log_file: Optional file to log the process. + :param max_tokens: Maximum number of tokens to generate. + :param temperature: Temperature for sampling. + :param max_retry: Maximum number of attempts to generate a valid output. + :param verbose: If True, verbose logging is enabled. + :return: Generated responses. + """ + if isinstance(substitution_dicts, dict): substitution_dicts = [substitution_dicts] - - if type(template_files) is str: + + if isinstance(template_files, str): template_files = [template_files] * len(substitution_dicts) - + assert len(template_files) == len(substitution_dicts), 'Number of templates and substitutions do not match' - + + # Create a dialogue for each template/substitution pair dialogs = { i: dialog_from_string(format_template(read_file(template_file), **substitutions)) for i, (template_file, substitutions) in enumerate(zip(template_files, substitution_dicts)) } - + outputs = {} input_counts = {} while len(dialogs) > 0: sample_ids = list(dialogs.keys())[:batch_size] batch = [dialogs[i] for i in sample_ids] generations = self.generate_responses(batch, log_file=log_file, max_new_tokens=max_tokens, temperature=temperature) - - for sample_id, generation in zip(sample_ids, generations): + + # Process the generated responses + for sample_id, generation in zip(sample_ids, generations): input_counts[sample_id] = input_counts.get(sample_id, 0) + 1 + + # If the maximum number of try-outs is exceeded, throw an error if input_counts[sample_id] > max_retry: raise Exception(f'Could not generate valid output for sample [{sample_id}]') - + + # If there's a specific function to parse the generations, try to apply it if parse_generation_fn is not None: try: outputs[sample_id] = parse_generation_fn(generation) @@ -108,7 +149,7 @@ def generate_from_template( del dialogs[sample_id] assert len(outputs) == len(substitution_dicts), 'Unexpected state: number of outputs and substitutions do not match' - + return [ outputs[i] for i in range(len(outputs)) diff --git a/align_system/language_model_lib/dialog.py b/align_system/language_model_lib/dialog.py deleted file mode 100644 index 5934f122..00000000 --- a/align_system/language_model_lib/dialog.py +++ /dev/null @@ -1,70 +0,0 @@ -import re - - -class Dialog: - - @classmethod - def from_string(cls, string): - dialog_markers = { - '=== system': 'system', - '=== user': 'user', - '=== assistant': 'assistant', - } - dialog = [] - lines = string.split('\n') - current_role = '' - current_content = '' - for line in lines: - if line.strip() in dialog_markers: - if current_role and current_content: - dialog.append({ - 'role': current_role, - 'content': current_content.strip() - }) - current_role = dialog_markers[line.strip()] - current_content = '' - else: - current_content += f'{line}\n' - if current_role and current_content: - dialog.append({ - 'role': current_role, - 'content': current_content.strip() - }) - return dialog - - - @classmethod - def from_template(cls, template, **substitutions): - cls.from_string(format_template(template, **substitutions)) - - - def __init__(self, messages): - self.messages = messages - - - def __str__(self): - output = '' - - for dialog_piece in self.messages: - role = dialog_piece['role'] - content = dialog_piece['content'] - output += f"=== {role}\n" - output += f"{content}\n" - - return output - - -def format_template(template, **substitutions): - for key, value in substitutions.items(): - key = '{{%s}}' % key - if not key in template: - raise Exception(f'Could not find key {key} in template') - template = template.replace(key, value) - - # ensure there are no strings sorrounded by {{ }} - matches = re.findall(r'{{.*?}}', template) - # if there are any matches, raise an exception - if len(matches) > 0: - raise Exception(f'Could not find values for {matches} in template') - - return template \ No newline at end of file diff --git a/align_system/language_model_lib/dialog_tokenizer.py b/align_system/language_model_lib/dialog_tokenizer.py index b558a974..89ac8a3f 100644 --- a/align_system/language_model_lib/dialog_tokenizer.py +++ b/align_system/language_model_lib/dialog_tokenizer.py @@ -1,19 +1,42 @@ from abc import abstractmethod +from typing import List, Dict +from transformers import PreTrainedTokenizerBase class DialogTokenizer: - - def __init__(self, tokenizer): + """ + Abstract base class for dialog tokenizers. + """ + def __init__(self, tokenizer: PreTrainedTokenizerBase): + """ + Initializes the dialog tokenizer. + + :param tokenizer: Pretrained tokenizer. + """ self.tokenizer = tokenizer @abstractmethod - def dialog_to_tokens(self, dialog_messages): + def dialog_to_tokens(self, dialog_messages: List[Dict[str, str]]) -> List[int]: + """ + Transforms a dialog to tokens. + + :param dialog_messages: List of dialogs. + :returns: List of tokens representing the dialog. + """ pass class Llama2DialogTokenizer(DialogTokenizer): + """ + Dialog tokenizer for Llama-2. + """ - - def dialog_to_tokens(self, dialog_messages): + def dialog_to_tokens(self, dialog_messages: List[Dict[str, str]]) -> List[int]: + """ + Transforms a dialog to tokens. Llama communicates using system, user and assistant roles. + + :param dialog_messages: List of dialogs. + :returns: List of tokens representing the dialog. + """ # Define instance and system borders B_INST, E_INST = "[INST]", "[/INST]" B_SYS, E_SYS = "<>\n", "\n<>\n\n" @@ -52,6 +75,7 @@ def dialog_to_tokens(self, dialog_messages): return dialog_tokens +# This mapping should ideally be updated when adding any new tokenizer classes to the project dialog_tokenizers = { 'meta-llama/Llama-2-7b-chat-hf': Llama2DialogTokenizer, 'meta-llama/Llama-2-13b-chat-hf': Llama2DialogTokenizer, diff --git a/align_system/language_model_lib/language_model.py b/align_system/language_model_lib/language_model.py index b5f47c98..1fe5de7c 100644 --- a/align_system/language_model_lib/language_model.py +++ b/align_system/language_model_lib/language_model.py @@ -1,56 +1,60 @@ import torch from transformers import AutoModelForCausalLM, AutoTokenizer - -from typing import List +from typing import List, Union, Optional, TextIO class LanguageModel: """ - A class that handles transformers Language Models + Class to define the Language Model. """ @classmethod - def load_model(cls, hf_model_name: str, precision: torch.dtype = torch.float32, device: str = 'cuda') -> 'LanguageModel': + def load_model(cls, + hf_model_name: str, + precision: torch.dtype = torch.float32, + device: str = 'cuda') -> 'LanguageModel': """ - Loads the specified transformer model and tokenizer. - - Args: - hf_model_name (str): The huggingface model name. - precision (torch.dtype, optional): The precision of the model weights. Defaults to torch.float32. - device (str, optional): The device to move the model to. Defaults to 'cuda'. + Load the language model. - Returns: - LanguageModel: An instance of this class with the loaded model and tokenizer. + :param hf_model_name: Name of the model in Huggingface. + :param precision: Precision of the model's weights. + :param device: Device to run the model on. + :return: Initialized LanguageModel object. """ + # Load the model from Huggingface model = AutoModelForCausalLM.from_pretrained(hf_model_name, torch_dtype=precision) tokenizer = AutoTokenizer.from_pretrained(hf_model_name) model = model.to(device) return cls(model, tokenizer) - - - def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer) -> None: + + def __init__(self, + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer) -> None: """ - Initializes the LanguageModel instance with the given model and tokenizer. + Initializes the language model. - Args: - model (AutoModelForCausalLM): The loaded transformer model. - tokenizer (AutoTokenizer): The loaded tokenizer. + :param model: Pretrained Huggingface model. + :param tokenizer: Tokenizer from Huggingface. """ self.model = model self.tokenizer = tokenizer - - - def generate_from_tokens(self, prompt_token_lists: List[List[int]], log_file=None, max_new_tokens: int=512, temperature: float=0.6, padding='left'): - """ - Generates text from a list of tokenized prompts. - Args: - prompt_token_lists (List[List[int]]): A batch of lists where each list is a sequence of tokens. - max_new_tokens (int, optional): The maximum number of tokens to generate. Defaults to 512. - temperature (float, optional): The temperature for the generation algorithm. Defaults to 0.6. - - Returns: - List[str]: The generated text for each prompt in the input list. Only contains text after the prompt. + def generate_from_tokens(self, + prompt_token_lists: List[List[int]], + log_file: Union[None, str, object] = None, + max_new_tokens: int = 512, + temperature: float = 0.6, + padding: str='left') -> List[str]: + """ + Generates text from the given list of tokens. + + :param prompt_token_lists: List of lists of tokens to generate the text. + :param log_file: Path to the log file. + :param max_new_tokens: Maximum number of new tokens to be generated. + :param temperature: Temperature for probability adjustment. + :param padding: Padding direction, either 'left' or 'right'. + :return: Generated text. """ + # Move to the model's device and unpack prompt_token_lists = [ torch.tensor(prompt_tokens).to(self.model.device).unsqueeze(0) for prompt_tokens in prompt_token_lists @@ -59,10 +63,12 @@ def generate_from_tokens(self, prompt_token_lists: List[List[int]], log_file=Non max_length = max([prompt_tokens.size(1) for prompt_tokens in prompt_token_lists]) pad_token_id = self.tokenizer.pad_token_id - # Pad each sequence to the max length + + # Padding function for the desired direction assert padding == 'left' or padding == 'right', f"Padding must be either 'left' or 'right', got {padding}" pad_fn = lambda prompt_token_size: (max_length - prompt_token_size, 0) if padding == 'left' else (0, max_length - prompt_token_size) - + + # Pad each sequence to the max length padded_prompt_token_lists = [ torch.nn.functional.pad(prompt_tokens, pad_fn(prompt_tokens.size(1)), value=pad_token_id) for prompt_tokens in prompt_token_lists @@ -123,38 +129,39 @@ def generate_from_tokens(self, prompt_token_lists: List[List[int]], log_file=Non return decoded_outputs - - - def generate(self, prompt_texts: List[str], log_file=None, max_new_tokens: int=512, temperature: float=0.6): + def generate(self, + prompt_texts: List[str], + log_file: Optional[TextIO] = None, + max_new_tokens: int = 512, + temperature: float = 0.6) -> List[str]: """ - Generates text from a list of prompts. + Generates text from the given list of inputs. - Args: - prompt_texts (List[str]): A list of prompts. - max_new_tokens (int, optional): The maximum number of tokens to generate. Defaults to 512. - temperature (float, optional): The temperature for the generation algorithm. Defaults to 0.6. - - Returns: - List[str]: The generated text for each prompt in the input list. Only contains text after the prompt. + :param prompt_texts: List of prompts to generate from. + :param log_file: Optional file object to write to + :param max_new_tokens: Maximum number of new tokens to be generated. + :param temperature: Temperature for probability adjustment. """ - # Convert text prompts to token prompts + # Convert the text to tokens and generate the text prompt_token_lists = [self.tokenizer.encode(prompt_text) for prompt_text in prompt_texts] return self.generate_from_tokens(prompt_token_lists, log_file, max_new_tokens, temperature) - - def generate_with_prefixes(self, prompt_texts: List[str], prefixes: List[str], log_file=None, max_new_tokens: int=512, temperature: float=0.6): + def generate_with_prefixes(self, + prompt_texts: List[str], + prefixes: List[str], + log_file: Optional[TextIO] = None, + max_new_tokens: int = 512, + temperature: float = 0.6) -> List[str]: """ - Generates text from a list of prompts with a list of prefixes. - - Args: - prompt_texts (List[str]): A list of prompts. - prefixes (List[str]): A list of prefixes. - max_new_tokens (int, optional): The maximum number of tokens to generate. Defaults to 512. - temperature (float, optional): The temperature for the generation algorithm. Defaults to 0.6. + Generates text from the given list of inputs with prefixes. - Returns: - List[str]: The generated text for each prompt in the input list. Includes the prefix but not the prompt. + :param prompt_texts: List of prompts to generate from. + :param prefixes: List of prefixes to prepend to the generated text. + :param log_file: Optional file object to write to + :param max_new_tokens: Maximum number of new tokens to be generated. + :param temperature: Temperature for probability adjustment. """ + # Combine the inputs with prefixes and generate the text combined_texts = [f'{prompt}{prefix}' for prompt, prefix in zip(prompt_texts, prefixes)] generations = self.generate(combined_texts, log_file, max_new_tokens, temperature) - return [f'{prefix}{generation}' for prefix, generation in zip(prefixes, generations)] \ No newline at end of file + return [f'{prefix}{generation}' for prefix, generation in zip(prefixes, generations)] diff --git a/align_system/language_model_lib/llama_2_kdma_predicting_adm.py b/align_system/language_model_lib/llama_2_kdma_predicting_adm.py index bfa4e940..24fa4350 100644 --- a/align_system/language_model_lib/llama_2_kdma_predicting_adm.py +++ b/align_system/language_model_lib/llama_2_kdma_predicting_adm.py @@ -1,26 +1,36 @@ import json - -from align_system.language_model_lib.chat_langauge_model import ChatLanguageModel +from typing import Union, List, Dict, Tuple, Optional, TextIO +from align_system.language_model_lib.chat_language_model import ChatLanguageModel from align_system.language_model_lib.util import extract_kdma_description class Llama2KDMAPredictingADM(ChatLanguageModel): - - def predict_outcomes( - self, - scenario, - probe, - choices, - log_file=None, - max_tokens=512, - temperature=0.6, - outcome_template_file='templates/predict_outcomes.md' - ): + + def predict_outcomes(self, + scenario_text: str, + probe_text: str, + choices: List[str], + log_file: Optional[TextIO] = None, + max_tokens: int = 512, + temperature: float = 0.6, + outcome_template_file: str = 'templates/predict_outcomes.md') -> List[str]: + """ + Predicts outcomes for given scenario, probe and choices. + + :param scenario: Scenario text. + :param probe: Probe text. + :param choices: Choices text. + :param log_file: Optional log file. + :param max_tokens: Maximum number of tokens to generate. + :param temperature: Temperature for sampling. + :param outcome_template_file: Template file for Outcomes. + :return: List of generated predictions. + """ return self.generate_from_template( outcome_template_file, [ { - 'scenario': scenario, - 'probe': probe, + 'scenario': scenario_text, + 'probe': probe_text, 'choice': choice, } for choice in choices @@ -29,27 +39,41 @@ def predict_outcomes( max_tokens=max_tokens, temperature=temperature ) + - - def predict_kdma_scores( - self, - scenario_text, - probe_text, - choice_texts, - predicted_outcomes=None, - generate_reasoning=True, - log_file=None, - max_new_tokens=512, - temperature=0.6, - kdma_template_file='templates/kdma.md', - kdma_descriptions_file='templates/bbn_kdma_descriptions.md', - ): + def predict_kdma_scores(self, + scenario_text: str, + probe_text: str, + choice_texts: List[str], + predicted_outcomes: Optional[List[str]] = None, + generate_reasoning: bool = True, + log_file: Optional[TextIO] = None, + max_new_tokens: int = 512, + temperature: float = 0.6, + kdma_template_file: str = 'templates/kdma.md', + kdma_descriptions_file: str = 'templates/bbn_kdma_descriptions.md') -> Union[List[Dict[str, float]], Tuple[List[Dict[str, float]], List[Dict[str, str]]]]: + """ + Predicts KDMA scores each choice text under the given scenario and probe. + + :param scenario_text: Scenario text. + :param probe_text: Probe text. + :param choice_texts: Choices text. + :param predicted_outcomes: Predicted outcomes. + :param generate_reasoning: Flag to generate reasoning. + :param log_file: Optional log file. + :param max_new_tokens: Maximum number of new tokens to generate. + :param temperature: Temperature for sampling. + :param kdma_template_file: Template file for KDMA prediction. + :param kdma_descriptions_file: Template file for KDMA descriptions. + :return: KDMA predictions. If generate_reasoning is True, return predictions and reasonings. + """ choice_ids = [f'choice_{i}' for i in range(len(choice_texts))] substitutions = [] info = [] kdma_descriptions = extract_kdma_description(kdma_descriptions_file) if predicted_outcomes is None: predicted_outcomes = [None] * len(choice_texts) + for choice_id, choice, outcome in zip(choice_ids, choice_texts, predicted_outcomes): for kdma, kdma_description in kdma_descriptions.items(): substitution = { @@ -66,7 +90,13 @@ def predict_kdma_scores( substitutions.append(substitution) info.append((choice_id, kdma)) - def parse_kdma_score_response(response): + def parse_kdma_score_response(response: str) -> Dict[str, Union[float, str]]: + """ + Parses KDMA score response. + + :param response: Response to parse. + :return: Dictionary with KDMA score and reasoning if generate_reasoning. + """ if generate_reasoning: start_idx = response.find('{') end_idx = response.rfind('}') @@ -84,7 +114,6 @@ def parse_kdma_score_response(response): response_json = { 'score': float(response[response.find(char):]) } - return response_json generations = self.generate_from_template( @@ -109,7 +138,6 @@ def parse_kdma_score_response(response): if generate_reasoning: choice_reasonings[kdma] = generation['reasoning'] - predicted_kdmas = [ predicted_kdmas[choice_id] @@ -124,64 +152,4 @@ def parse_kdma_score_response(response): if generate_reasoning: return predicted_kdmas, reasonings else: - return predicted_kdmas - - def make_aligned_descision( - self, - scenario, - probe, - choices, - target_kdmas, - alignment_fn, - predict_outcomes=True, - generate_reasoning=True, - kdma_descriptions_file='templates/bbn_kdma_descriptions.md', - outcome_template_file='templates/predict_outcomes.md', - kdma_template_file='templates/predict_kdma_scores_reasoning.md', - ): - # Generate the outcomes for each choice - outcomes = None - if predict_outcomes: - outcomes = self.predict_outcomes( - scenario, - probe, - choices, - outcome_template_file=outcome_template_file - ) - - assert len(choices) == len(outcomes), 'Unexpected state: number of choices and outcomes do not match' - - # Get the scores and reasonings for each choice - predicted_kdma_scores = self.predict_kdma_scores( - scenario, - probe, - choices, - outcomes=outcomes, - kdma_template_file=kdma_template_file, - kdma_descriptions_file=kdma_descriptions_file - ) - - if generate_reasoning: - scores, reasonings = predicted_kdma_scores - else: - scores = predicted_kdma_scores - - assert len(choices) == len(scores), 'Unexpected state: number of choices and scores do not match' - - # Compute the similarity score for each choice - alignment_scores = [] - for score in scores: - alignment_scores.append(alignment_fn(target_kdmas, score)) - - max_idx = alignment_scores.index(max(alignment_scores)) - - justification = { - 'choice': choices[max_idx], - 'outcome': outcomes[max_idx], - 'kdma_scores': scores[max_idx], - } - - if generate_reasoning: - justification['kdma_reasonings'] = reasonings[max_idx] - - return max_idx, justification \ No newline at end of file + return predicted_kdmas \ No newline at end of file diff --git a/align_system/language_model_lib/test_chat_language_model.py b/align_system/language_model_lib/test_chat_language_model.py index c73d9292..004d5abb 100644 --- a/align_system/language_model_lib/test_chat_language_model.py +++ b/align_system/language_model_lib/test_chat_language_model.py @@ -1,6 +1,6 @@ import pytest -from chat_langauge_model import ChatLanguageModel +from align_system.language_model_lib.chat_language_model import ChatLanguageModel MODEL_TO_TEST = 'meta-llama/Llama-2-7b-chat-hf' diff --git a/align_system/language_model_lib/test_language_model.py b/align_system/language_model_lib/test_language_model.py index d22cf99e..41e4f4b5 100644 --- a/align_system/language_model_lib/test_language_model.py +++ b/align_system/language_model_lib/test_language_model.py @@ -1,7 +1,7 @@ import pytest import torch -from language_model import LanguageModel +from align_system.language_model_lib.language_model import LanguageModel MODEL_TO_TEST = 'gpt2' # Use a smaller model for testing diff --git a/align_system/language_model_lib/util.py b/align_system/language_model_lib/util.py index 079ebcf8..e5c97f4f 100644 --- a/align_system/language_model_lib/util.py +++ b/align_system/language_model_lib/util.py @@ -1,6 +1,14 @@ import re +from typing import List, Dict, Union -def dialog_from_string(dialog_string): +def dialog_from_string(dialog_string: str) -> List[Dict[str, str]]: + """ + Transforms the dialog in string format to a list of dictionary format. + + :param dialog_string: Dialog in string format. + :return: Dialog in the list of dictionary format. + """ + # Dictionary to map string markers to role names dialog_markers = { '=== system': 'system', '=== user': 'user', @@ -8,19 +16,21 @@ def dialog_from_string(dialog_string): } dialog = [] lines = dialog_string.split('\n') + current_role = '' current_content = '' for line in lines: - if line.strip() in dialog_markers: - if current_role and current_content: + if line.strip() in dialog_markers: # If a line indicates a role change + if current_role and current_content: # Save the previous role's dialog dialog.append({ 'role': current_role, 'content': current_content.strip() }) - current_role = dialog_markers[line.strip()] + current_role = dialog_markers[line.strip()] # Set the new role current_content = '' - else: + else: # Continue appending content if the role hasn't changed current_content += f'{line}\n' + # Append the last piece of dialog if current_role and current_content: dialog.append({ 'role': current_role, @@ -28,18 +38,31 @@ def dialog_from_string(dialog_string): }) return dialog -def read_file(file_path): +def read_file(file_path: str) -> str: + """ + Reads a file and returns its content. + + :param file_path: Path to the file to read. + :return: The content of the file. + """ with open(file_path, 'r') as f: return f.read() -def format_template(template, **substitutions): +def format_template(template: str, **substitutions: str) -> str: + """ + Replaces placeholders in a template with provided substitutions. + + :param template: The template with placeholders indicated as {{placeholder}}. + :param substitutions: The substitutions to replace in the template. + :return: The template with all placeholders substituted. + """ for key, value in substitutions.items(): key = '{{%s}}' % key if not key in template: raise Exception(f'Could not find key {key} in template') template = template.replace(key, value) - # ensure there are no strings sorrounded by {{ }} + # ensure there are no strings surrounded by {{ }} matches = re.findall(r'{{.*?}}', template) # if there are any matches, raise an exception if len(matches) > 0: @@ -48,19 +71,25 @@ def format_template(template, **substitutions): return template -def extract_kdma_description(descriptions_file): +def extract_kdma_description(descriptions_file: str) -> Dict[str, str]: + """ + Extracts KDMA description from the file. + + :param descriptions_file: File with KDMA descriptions. + :return: Dictionary of KDMA descriptions. + """ kdma_dict = {} kdma_name = None kdma_description = '' with open(descriptions_file, 'r') as f: for line in f: - if line.startswith('#'): - if kdma_name is not None: + if line.startswith('#'): # The line is a KDMA tag + if kdma_name is not None: # Save the previous KDMA's tag and description kdma_dict[kdma_name] = kdma_description.strip() - kdma_name = line[1:].strip() + kdma_name = line[1:].strip() # The new KDMA tag kdma_description = '' - else: + else: # The line is a part of the KDMA's description kdma_description += line if kdma_name is not None and kdma_name not in kdma_dict: @@ -69,9 +98,14 @@ def extract_kdma_description(descriptions_file): return kdma_dict -def dialog_to_string(dialog): - output = '' +def dialog_to_string(dialog: List[Dict[str, str]]) -> str: + """ + Transforms the dialog in list of dictionary to string format. + :param dialog: Dialog in list of dictionary format. + :return: Dialog in string format. + """ + output = '' for dialog_piece in dialog: role = dialog_piece['role'] content = dialog_piece['content'] From a0b9fd355d504598930a89bb397520f291894860 Mon Sep 17 00:00:00 2001 From: Bill Ray Date: Fri, 20 Oct 2023 14:58:13 -0400 Subject: [PATCH 3/3] reorganized language model lib code --- align_system/algorithms/lib/__init__.py | 15 ++++ align_system/algorithms/lib/chat/__init__.py | 0 .../lib/chat}/chat_language_model.py | 8 +-- .../lib/chat}/dialog_tokenizer.py | 0 .../lib}/language_model.py | 0 .../lib/templates/bbn_kdma_descriptions.yml | 19 +++++ .../algorithms/lib/templates/pred_kdma_O.txt | 30 ++++++++ .../algorithms/lib/templates/pred_kdma_RO.txt | 31 ++++++++ .../algorithms/lib/templates/pred_outcome.txt | 14 ++++ .../lib}/util.py | 72 +++++++------------ .../llama_2_kdma_predicting_adm.py | 25 ++++--- align_system/language_model_lib/__init__.py | 15 ---- .../test_chat_language_model.py | 2 +- .../test_language_model.py | 2 +- 14 files changed, 156 insertions(+), 77 deletions(-) create mode 100644 align_system/algorithms/lib/__init__.py create mode 100644 align_system/algorithms/lib/chat/__init__.py rename align_system/{language_model_lib => algorithms/lib/chat}/chat_language_model.py (95%) rename align_system/{language_model_lib => algorithms/lib/chat}/dialog_tokenizer.py (100%) rename align_system/{language_model_lib => algorithms/lib}/language_model.py (100%) create mode 100644 align_system/algorithms/lib/templates/bbn_kdma_descriptions.yml create mode 100644 align_system/algorithms/lib/templates/pred_kdma_O.txt create mode 100644 align_system/algorithms/lib/templates/pred_kdma_RO.txt create mode 100644 align_system/algorithms/lib/templates/pred_outcome.txt rename align_system/{language_model_lib => algorithms/lib}/util.py (67%) rename align_system/{language_model_lib => algorithms}/llama_2_kdma_predicting_adm.py (86%) delete mode 100644 align_system/language_model_lib/__init__.py rename align_system/{language_model_lib => tests}/test_chat_language_model.py (93%) rename align_system/{language_model_lib => tests}/test_language_model.py (94%) diff --git a/align_system/algorithms/lib/__init__.py b/align_system/algorithms/lib/__init__.py new file mode 100644 index 00000000..84f2f138 --- /dev/null +++ b/align_system/algorithms/lib/__init__.py @@ -0,0 +1,15 @@ +from importlib import reload + +def reload_all(): + # Useful function for developing in an interactive environment without having to restart the kernel + + from align_system.algorithms.lib import util + from align_system.algorithms.lib import language_model as lm + from align_system.algorithms.lib.chat import dialog_tokenizer as dt + from align_system.algorithms.lib.chat import chat_language_model as clm + from align_system.algorithms import llama_2_kdma_predicting_adm as kpa + + + # Reload in the correct order + for module in [util, lm, dt, clm, kpa]: + reload(module) diff --git a/align_system/algorithms/lib/chat/__init__.py b/align_system/algorithms/lib/chat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/align_system/language_model_lib/chat_language_model.py b/align_system/algorithms/lib/chat/chat_language_model.py similarity index 95% rename from align_system/language_model_lib/chat_language_model.py rename to align_system/algorithms/lib/chat/chat_language_model.py index 0d8093aa..6326f4dc 100644 --- a/align_system/language_model_lib/chat_language_model.py +++ b/align_system/algorithms/lib/chat/chat_language_model.py @@ -1,8 +1,8 @@ from typing import List, Dict, Optional, Callable, Union, TextIO -from align_system.language_model_lib.language_model import LanguageModel -from align_system.language_model_lib.dialog_tokenizer import dialog_tokenizers -from align_system.language_model_lib.util import read_file, format_template, dialog_from_string, dialog_to_string +from align_system.algorithms.lib.language_model import LanguageModel +from align_system.algorithms.lib.chat.dialog_tokenizer import dialog_tokenizers +from align_system.algorithms.lib.util import read_template, format_template, dialog_from_string, dialog_to_string class ChatLanguageModel(LanguageModel): @@ -115,7 +115,7 @@ def generate_from_template( # Create a dialogue for each template/substitution pair dialogs = { - i: dialog_from_string(format_template(read_file(template_file), **substitutions)) + i: dialog_from_string(format_template(read_template(template_file), **substitutions)) for i, (template_file, substitutions) in enumerate(zip(template_files, substitution_dicts)) } diff --git a/align_system/language_model_lib/dialog_tokenizer.py b/align_system/algorithms/lib/chat/dialog_tokenizer.py similarity index 100% rename from align_system/language_model_lib/dialog_tokenizer.py rename to align_system/algorithms/lib/chat/dialog_tokenizer.py diff --git a/align_system/language_model_lib/language_model.py b/align_system/algorithms/lib/language_model.py similarity index 100% rename from align_system/language_model_lib/language_model.py rename to align_system/algorithms/lib/language_model.py diff --git a/align_system/algorithms/lib/templates/bbn_kdma_descriptions.yml b/align_system/algorithms/lib/templates/bbn_kdma_descriptions.yml new file mode 100644 index 00000000..1b935111 --- /dev/null +++ b/align_system/algorithms/lib/templates/bbn_kdma_descriptions.yml @@ -0,0 +1,19 @@ +basic_knowledge: + name: Basic knowledge + description: Basic knowledge means being able to apply domain knowledge in typical situations. For this program, domain knowledge focuses on initial triage, tactical combat casualty care, and field care. Someone with high basic knowledge will be able to assess and prioritize which injuries to treat immediately, how to continue monitoring and assessing casualties, and how to prioritize and prepare casualties for evacuation. Someone with low basic knowledge might treat less urgent injuries first, might not apply an appropriate treatment, or might freeze up and not take action. + +risk_aversion: + name: Risk aversion + description: Risk aversion is a tendency to avoid uncertainty and prefer taking actions whose expected outcomes have a lower range of variation. A person with high risk aversion may prefer an action that has somewhat lower total expected value when it also has less difference between the best and worst expected outcome. A person with low risk aversion may prefer an action that has a superior best expected outcome, even if the total expected value is lower. We think this might manifest differently in the military medical domain than in the abstract decision making context. We suspect that in a military medical domain, lower risk choices may be decisions that keep open multiple possible paths, while high risk choices may be decisions that commit to a single possible path. + +protocol_focus: + name: Protocol focus + description: Protocol focus is the tendency to prioritize following a protocol, guideline, rule, or process, instead of considering specific context factors as reasons to make exceptions to the protocol or rule. A high protocol focus person will stick to the rules or guidelines, even when it seems like that may waste time, waste effort, or cause unhappiness. A low protocol focus person will make exceptions or break rules when they believe that the rule was not intended to apply to that particular situation, or when they believe that following the rule will produce an inferior result. + +fairness: + name: Fairness + description: Fairness is treating different people in the same way, such that no person is valued more based on their group membership, identity, or prior actions. This can manifest as favoring those who you have a positive personal relationship with, those who are part of a group you also belong to, those who have high social status, or those who were injured while being either virtuous or innocent. A person with high fairness will provide equal assistance to everyone who needs it. A person with low fairness will prioritize providing assistance to those that they care more about, who are more similar to them, or who may be more useful in some way. The concept of fairness may also describe a preference for equal outcomes versus equal assistance. + +utilitarianism: + name: Utilitarianism + description: In this context, utilitarianism is the priority placed on saving the most people or contributing to the best medical outcome for the most people. A person with high utilitarianism will try to save the most people, which under conditions of limited resources may mean withholding or rationing care to patients for whom treatment has a low probability of improving outcome. A person with low utilitarianism may instead allocate limited resources based on other considerations, such as their personal feelings toward patients or trying to demonstrate other values or virtues (such as kindness, fairness, respect, or loyalty). The priority placed on expected utility of the overall outcome. High utilitarianism also suggests a willingness to harm some individuals, through both omission and commission, as part of a tradeoff to increase the overall benefit. It may also consider the likely trajectory of a patient post-treatment and thus deprioritize treating those who are at high risk of dying soon due to other factors. \ No newline at end of file diff --git a/align_system/algorithms/lib/templates/pred_kdma_O.txt b/align_system/algorithms/lib/templates/pred_kdma_O.txt new file mode 100644 index 00000000..e7742b2e --- /dev/null +++ b/align_system/algorithms/lib/templates/pred_kdma_O.txt @@ -0,0 +1,30 @@ +=== system +Given the following scenario, response, and predicted outcome your job is to asses, on a scale from 0 to 10, how well the response characterizes the decision making behavior of someone who acting with high {{kdma}}. Provide one sentence of reasoning. + +{{kdma_description}} + +A score of 0 means that the response is a poor characterization of someone who is acting with high {{kdma}}. A score of 10 means that the response is a perfect characterization of someone who is acting with high {{kdma}}. + +Ensure your response strictly adheres to following JSON format +{ + "score": int score +} +Provide no other information or text in your response. Your response ends when you reach the end of the JSON object. +=== user +Consider the following scenario: +``` +{{scenario}} +{{probe}} +``` +Review this response to the scenario and consider the corresponding predicted outcome: +``` +Response: +{{choice}} + +Predicted Outcome: +{{outcome}} +``` +Generate a JSON object that indicates how well the response characterizes the decision making behavior of someone who is acting with high {{kdma}}. Your response ends when you reach the end of the JSON object. +=== assistant +{ + "score": \ No newline at end of file diff --git a/align_system/algorithms/lib/templates/pred_kdma_RO.txt b/align_system/algorithms/lib/templates/pred_kdma_RO.txt new file mode 100644 index 00000000..5893afec --- /dev/null +++ b/align_system/algorithms/lib/templates/pred_kdma_RO.txt @@ -0,0 +1,31 @@ +=== system +Given the following scenario, response, and predicted outcome your job is to asses, on a scale from 0 to 10, how well the response characterizes the decision making behavior of someone who acting with high {{kdma}}. Provide one sentence of reasoning. + +{{kdma_description}} + +A score of 0 means that the response is a poor characterization of someone who is acting with high {{kdma}}. A score of 10 means that the response is a perfect characterization of someone who is acting with high {{kdma}}. + +Ensure your response strictly adheres to following JSON format +{ + "reasoning": "one-sentence reasoning explaining your score", + "score": int score +} +Provide no other information or text in your response. Your response ends when you reach the end of the JSON object. +=== user +Consider the following scenario: +``` +{{scenario}} +{{probe}} +``` +Review this response to the scenario and consider the corresponding predicted outcome: +``` +Response: +{{choice}} + +Predicted Outcome: +{{outcome}} +``` +Generate a JSON object that provides one sentence of reasoning and indicates how well the response characterizes the decision making behavior of someone who is acting with high {{kdma}}. Your response ends when you reach the end of the JSON object. +=== assistant +{ + "reasoning": " \ No newline at end of file diff --git a/align_system/algorithms/lib/templates/pred_outcome.txt b/align_system/algorithms/lib/templates/pred_outcome.txt new file mode 100644 index 00000000..9aea5572 --- /dev/null +++ b/align_system/algorithms/lib/templates/pred_outcome.txt @@ -0,0 +1,14 @@ +=== system +In a clear and consice way describe the outcomes of answering the given question with the given response. Do not provide additional advice for the situation. Do not indicate wether or not you agree with the response. Simply analyze the outcomes of the response. Limit your analysis to a single sentence. +=== user +Consider the following scenario: +``` +{{scenario}} +{{probe}} +``` +What would happen if you chose to answer the question with the following response? +``` +{{choice}} +``` +=== assistant +Outcome: \ No newline at end of file diff --git a/align_system/language_model_lib/util.py b/align_system/algorithms/lib/util.py similarity index 67% rename from align_system/language_model_lib/util.py rename to align_system/algorithms/lib/util.py index e5c97f4f..ee2ff783 100644 --- a/align_system/language_model_lib/util.py +++ b/align_system/algorithms/lib/util.py @@ -1,5 +1,7 @@ import re -from typing import List, Dict, Union +import os +from typing import List, Dict + def dialog_from_string(dialog_string: str) -> List[Dict[str, str]]: """ @@ -38,15 +40,23 @@ def dialog_from_string(dialog_string: str) -> List[Dict[str, str]]: }) return dialog -def read_file(file_path: str) -> str: + +def dialog_to_string(dialog: List[Dict[str, str]]) -> str: """ - Reads a file and returns its content. + Transforms the dialog in list of dictionary to string format. - :param file_path: Path to the file to read. - :return: The content of the file. + :param dialog: Dialog in list of dictionary format. + :return: Dialog in string format. """ - with open(file_path, 'r') as f: - return f.read() + output = '' + for dialog_piece in dialog: + role = dialog_piece['role'] + content = dialog_piece['content'] + output += f"=== {role}\n" + output += f"{content}\n" + + return output + def format_template(template: str, **substitutions: str) -> str: """ @@ -71,45 +81,13 @@ def format_template(template: str, **substitutions: str) -> str: return template -def extract_kdma_description(descriptions_file: str) -> Dict[str, str]: - """ - Extracts KDMA description from the file. - - :param descriptions_file: File with KDMA descriptions. - :return: Dictionary of KDMA descriptions. - """ - kdma_dict = {} - kdma_name = None - kdma_description = '' - - with open(descriptions_file, 'r') as f: - for line in f: - if line.startswith('#'): # The line is a KDMA tag - if kdma_name is not None: # Save the previous KDMA's tag and description - kdma_dict[kdma_name] = kdma_description.strip() - kdma_name = line[1:].strip() # The new KDMA tag - kdma_description = '' - else: # The line is a part of the KDMA's description - kdma_description += line - - if kdma_name is not None and kdma_name not in kdma_dict: - kdma_dict[kdma_name] = kdma_description.strip() - - return kdma_dict - - -def dialog_to_string(dialog: List[Dict[str, str]]) -> str: - """ - Transforms the dialog in list of dictionary to string format. +def read_template(template_file_name: str, template_dir='templates') -> str: + current_directory = os.path.dirname(os.path.abspath(__file__)) + full_path = os.path.join(current_directory, template_dir, template_file_name) + + with open(full_path, 'r') as template_file: + template = template_file.read() + + return template - :param dialog: Dialog in list of dictionary format. - :return: Dialog in string format. - """ - output = '' - for dialog_piece in dialog: - role = dialog_piece['role'] - content = dialog_piece['content'] - output += f"=== {role}\n" - output += f"{content}\n" - return output \ No newline at end of file diff --git a/align_system/language_model_lib/llama_2_kdma_predicting_adm.py b/align_system/algorithms/llama_2_kdma_predicting_adm.py similarity index 86% rename from align_system/language_model_lib/llama_2_kdma_predicting_adm.py rename to align_system/algorithms/llama_2_kdma_predicting_adm.py index 24fa4350..5464450c 100644 --- a/align_system/language_model_lib/llama_2_kdma_predicting_adm.py +++ b/align_system/algorithms/llama_2_kdma_predicting_adm.py @@ -1,7 +1,8 @@ import json +import yaml +import os from typing import Union, List, Dict, Tuple, Optional, TextIO -from align_system.language_model_lib.chat_language_model import ChatLanguageModel -from align_system.language_model_lib.util import extract_kdma_description +from align_system.algorithms.lib.chat.chat_language_model import ChatLanguageModel class Llama2KDMAPredictingADM(ChatLanguageModel): @@ -12,7 +13,7 @@ def predict_outcomes(self, log_file: Optional[TextIO] = None, max_tokens: int = 512, temperature: float = 0.6, - outcome_template_file: str = 'templates/predict_outcomes.md') -> List[str]: + outcome_template_file: str = 'pred_outcome.txt') -> List[str]: """ Predicts outcomes for given scenario, probe and choices. @@ -50,8 +51,8 @@ def predict_kdma_scores(self, log_file: Optional[TextIO] = None, max_new_tokens: int = 512, temperature: float = 0.6, - kdma_template_file: str = 'templates/kdma.md', - kdma_descriptions_file: str = 'templates/bbn_kdma_descriptions.md') -> Union[List[Dict[str, float]], Tuple[List[Dict[str, float]], List[Dict[str, str]]]]: + kdma_template_file: str = 'pred_kdma_RO.txt', + kdma_descriptions_file: str = 'lib/templates/bbn_kdma_descriptions.yml') -> Union[List[Dict[str, float]], Tuple[List[Dict[str, float]], List[Dict[str, str]]]]: """ Predicts KDMA scores each choice text under the given scenario and probe. @@ -70,15 +71,21 @@ def predict_kdma_scores(self, choice_ids = [f'choice_{i}' for i in range(len(choice_texts))] substitutions = [] info = [] - kdma_descriptions = extract_kdma_description(kdma_descriptions_file) + + relative_dir = os.path.dirname(__file__) + kdma_descriptions_file_path = os.path.join(relative_dir, kdma_descriptions_file) + + with open(kdma_descriptions_file_path, 'r') as f: + kdma_descriptions = yaml.load(f, Loader=yaml.FullLoader) + if predicted_outcomes is None: predicted_outcomes = [None] * len(choice_texts) for choice_id, choice, outcome in zip(choice_ids, choice_texts, predicted_outcomes): - for kdma, kdma_description in kdma_descriptions.items(): + for kdma, kdma_info in kdma_descriptions.items(): substitution = { - 'kdma': kdma, - 'kdma_description': kdma_description, + 'kdma': kdma_info['name'], + 'kdma_description': kdma_info['description'], 'scenario': scenario_text, 'probe': probe_text, 'choice': choice, diff --git a/align_system/language_model_lib/__init__.py b/align_system/language_model_lib/__init__.py deleted file mode 100644 index 0f451679..00000000 --- a/align_system/language_model_lib/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from importlib import reload - -def reload_all(): - # Import the modules inside this function to ensure they're available for reloading - - from . import util - from . import language_model as lm - from . import dialog_tokenizer as dt - from . import chat_langauge_model as clm - from . import llama_2_kdma_predicting_adm as kpa - - - # Reload in the correct order - for module in [util, lm, dt, clm, kpa]: - reload(module) diff --git a/align_system/language_model_lib/test_chat_language_model.py b/align_system/tests/test_chat_language_model.py similarity index 93% rename from align_system/language_model_lib/test_chat_language_model.py rename to align_system/tests/test_chat_language_model.py index 004d5abb..aa81e701 100644 --- a/align_system/language_model_lib/test_chat_language_model.py +++ b/align_system/tests/test_chat_language_model.py @@ -1,6 +1,6 @@ import pytest -from align_system.language_model_lib.chat_language_model import ChatLanguageModel +from align_system.algorithms.lib.chat.chat_language_model import ChatLanguageModel MODEL_TO_TEST = 'meta-llama/Llama-2-7b-chat-hf' diff --git a/align_system/language_model_lib/test_language_model.py b/align_system/tests/test_language_model.py similarity index 94% rename from align_system/language_model_lib/test_language_model.py rename to align_system/tests/test_language_model.py index 41e4f4b5..df068b2b 100644 --- a/align_system/language_model_lib/test_language_model.py +++ b/align_system/tests/test_language_model.py @@ -1,7 +1,7 @@ import pytest import torch -from align_system.language_model_lib.language_model import LanguageModel +from align_system.algorithms.lib.language_model import LanguageModel MODEL_TO_TEST = 'gpt2' # Use a smaller model for testing