Source code for fairdiverse.search.trainer

import os
import yaml


class SRDTrainer(object):
    def __init__(self, train_config):
        """
        Initialize post-processing and base models.

        :param train_config: Your custom config dictionary.
        """
        self.train_config = train_config
    def load_configs(self, dir):
        """
        Loads and merges configuration files for the model, dataset, and evaluation.

        This function loads multiple YAML configuration files, including dataset-specific
        settings, model configurations, and evaluation parameters. All configurations are
        merged, with the highest priority given to the class's own `train_config` attribute.

        :param dir: The directory of the processed dataset; it is stored in the returned config as `data_dir`.
        :return: A dictionary containing the merged configuration from all files.
        """
        print("start to load dataset config...")
        with open(os.path.join(self.train_config['task'], "properties", "dataset",
                               self.train_config['dataset'].lower() + ".yaml"), 'r') as f:
            config = yaml.safe_load(f)
        config.update({'data_dir': dir})

        print("start to load model config...")
        with open(os.path.join(self.train_config['task'], "properties", "models",
                               self.train_config['model'].lower() + ".yaml"), 'r') as f:
            model_config = yaml.safe_load(f)
        config.update(model_config)

        with open(os.path.join("search", "properties", "evaluation.yaml"), 'r') as f:
            config.update(yaml.safe_load(f))

        config.update(self.train_config)  # train_config has the highest priority
        print("your loading config is:")
        print(config)
        return config
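    # Illustrative sketch only (not part of the trainer): how the layered updates above resolve
    # conflicts. Later dict.update calls overwrite earlier keys, so values supplied in
    # train_config always win. The keys and values below are placeholders, not real settings.
    #
    #   dataset_cfg    = {'device': 'cpu', 'max_doc_num': 50}
    #   model_cfg      = {'learning_rate': 1e-3}
    #   evaluation_cfg = {'metrics': ['alpha-nDCG']}
    #   train_cfg      = {'device': 'cuda:0'}
    #
    #   merged = {}
    #   for cfg in (dataset_cfg, model_cfg, evaluation_cfg, train_cfg):
    #       merged.update(cfg)
    #   # merged['device'] == 'cuda:0'  -> train_config overrides the YAML files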
    def train(self):
        """
        Main workflow for training the post-processing search model.
        """
        dir = os.path.join(self.train_config['task'], "processed_dataset", self.train_config['dataset'])
        config = self.load_configs(dir)

        if os.path.exists(os.path.join(config['task'], "processed_dataset", config['dataset'],
                                       config['model'])) and not config['reprocess']:
            print("Data has been processed, start to load the dataset...")
        else:
            print("start to process data...")
            # Build the base processed dataset first if div_query.data has not been generated yet.
            if not os.path.exists(os.path.join(config['data_dir'], 'div_query.data')):
                from .utils.process_dataset import data_process
                data_process(config)
            if config['model'].lower() == 'desa':
                from .utils.process_desa import Process
                Process(config)
            elif config['model'].lower() == 'daletor':
                from .utils.process_daletor import Process
                Process(config)
            elif config['model'].lower() in ('xquad', 'pm2'):
                from .utils.process_bm25 import generate_bm25_scores_for_query
                generate_bm25_scores_for_query(config)
            elif config['model'].lower() == 'llm':
                pass
            else:
                raise NotImplementedError(f"Not supported model type: {config['model']}")

        print("start to load dataset......")
        self.device = config['device']

        if config['mode'] == 'test' and config['best_model_list'] != []:
            print("start to test the model...")
            from .post_evaluator import get_global_fullset_metric
            get_global_fullset_metric(config)
        elif config['mode'] == 'train':
            """
            To implement your own supervised method, first re-write the xxx_run function and then
            implement your own model. You can also copy one of the two example xxx_run
            implementations for a quick start. For unsupervised methods, just implement the
            model function.
            """
            if config['model'].lower() == 'desa':
                from .datasets.DESA import DESA_run
                DESA_run(config)
            elif config['model'].lower() == 'daletor':
                from .datasets.DALETOR import DALETOR_run
                DALETOR_run(config)
            elif config['model'].lower() == 'xquad':
                from .postprocessing_model.xQuAD import xQuAD
                xquad = xQuAD()
                xquad.rerank(config)
            elif config['model'].lower() == 'pm2':
                from .postprocessing_model.PM2 import PM2
                pm2 = PM2()
                pm2.rerank(config)
            elif config['model'].lower() == 'llm':
                from .datasets.LLM import llm_run
                llm_run(config)
            else:
                raise NotImplementedError(f"Not supported model type: {config['model']}")
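
# A minimal usage sketch (illustrative, not part of the module). The task, dataset, and model
# values below are placeholders; substitute the names of your own dataset and the
# post-processing model you want to run. It assumes the layout read by load_configs: a task
# directory containing "properties" (dataset/model/evaluation YAML files) and "processed_dataset".
if __name__ == "__main__":
    example_config = {
        'task': 'search',          # top-level task directory
        'dataset': 'clueweb09',    # placeholder; must match a YAML file under properties/dataset
        'model': 'pm2',            # one of: desa, daletor, xquad, pm2, llm
        'mode': 'train',           # 'train' or 'test'
        'reprocess': False,        # set True to force re-building the processed dataset
        'device': 'cpu',           # overrides any device set in the YAML files
        'best_model_list': [],     # only used when mode == 'test'
    }
    trainer = SRDTrainer(example_config)
    trainer.train()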