Online dict embeddings + updated embedding instructions

This commit is contained in:
RobbeDeWaele
2023-04-23 11:22:32 +02:00
parent 3e9e2196e9
commit 9f5309e878
3 changed files with 95 additions and 3 deletions

View File

@@ -68,6 +68,7 @@ data_loader = DataLoader(
shuffle=False,
collate_fn=collate_fn_padd,
pin_memory=torch.cuda.is_available(),
#num_workers=0, # Uncomment this line (and comment out next line) if you want to disable multithreading
num_workers=multiprocessing.cpu_count(),
worker_init_fn=seed_worker,
generator=generator,
@@ -88,7 +89,7 @@ with torch.no_grad():
df = pd.read_csv(args.dataset)
df["embeddings"] = embeddings
df = df[['embeddings', 'label_name', 'labels']]
df['embeddings'] = df['embeddings'].apply(lambda x: x.tolist()[0])
df['embeddings2'] = df['embeddings'].apply(lambda x: x.tolist()[0])
if args.format == 'json':
df.to_json(args.output, orient='records')