Added ability to finetune models

This commit is contained in:
2023-04-21 11:22:47 +00:00
parent 151eefa1de
commit 3e9e2196e9
7 changed files with 36 additions and 20 deletions

View File

@@ -30,12 +30,15 @@ def seed_worker(worker_id):
generator = torch.Generator()
generator.manual_seed(seed)
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
def parse_args():
parser = argparse.ArgumentParser(description='Export embeddings')
parser.add_argument('--checkpoint', type=str, default=None, help='Path to checkpoint')
parser.add_argument('--output', type=str, default=None, help='Path to output')
parser.add_argument('--dataset', type=str, default=None, help='Path to data')
parser.add_argument('--format', type=str, default='csv', help='Format of the output file (csv, json)')
args = parser.parse_args()
return args
@@ -85,7 +88,9 @@ with torch.no_grad():
df = pd.read_csv(args.dataset)
df["embeddings"] = embeddings
df = df[['embeddings', 'label_name', 'labels']]
df['embeddings2'] = df['embeddings'].apply(lambda x: x.tolist())
df['embeddings'] = df['embeddings'].apply(lambda x: x.tolist()[0])
df.to_csv(args.output, index=False)
if args.format == 'json':
df.to_json(args.output, orient='records')
elif args.format == 'csv':
df.to_csv(args.output, index=False)