Online dict embeddings + updated embedding instructions
This commit is contained in:
@@ -68,6 +68,7 @@ data_loader = DataLoader(
|
||||
shuffle=False,
|
||||
collate_fn=collate_fn_padd,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
#num_workers=0, # Uncomment this line (and comment out the next line) to disable multiprocess data loading and load batches in the main process
|
||||
num_workers=multiprocessing.cpu_count(),
|
||||
worker_init_fn=seed_worker,
|
||||
generator=generator,
|
||||
@@ -88,7 +89,7 @@ with torch.no_grad():
|
||||
df = pd.read_csv(args.dataset)
|
||||
df["embeddings"] = embeddings
|
||||
df = df[['embeddings', 'label_name', 'labels']]
|
||||
df['embeddings'] = df['embeddings'].apply(lambda x: x.tolist()[0])
|
||||
df['embeddings2'] = df['embeddings'].apply(lambda x: x.tolist()[0])
|
||||
|
||||
if args.format == 'json':
|
||||
df.to_json(args.output, orient='records')
|
||||
|
||||
Reference in New Issue
Block a user