Added ability to finetune models
This commit is contained in:
@@ -30,12 +30,15 @@ def seed_worker(worker_id):
|
||||
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
import os
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Export embeddings')
|
||||
parser.add_argument('--checkpoint', type=str, default=None, help='Path to checkpoint')
|
||||
parser.add_argument('--output', type=str, default=None, help='Path to output')
|
||||
parser.add_argument('--dataset', type=str, default=None, help='Path to data')
|
||||
parser.add_argument('--format', type=str, default='csv', help='Format of the output file (csv, json)')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
@@ -85,7 +88,9 @@ with torch.no_grad():
|
||||
df = pd.read_csv(args.dataset)
|
||||
df["embeddings"] = embeddings
|
||||
df = df[['embeddings', 'label_name', 'labels']]
|
||||
df['embeddings2'] = df['embeddings'].apply(lambda x: x.tolist())
|
||||
df['embeddings'] = df['embeddings'].apply(lambda x: x.tolist()[0])
|
||||
|
||||
|
||||
df.to_csv(args.output, index=False)
|
||||
if args.format == 'json':
|
||||
df.to_json(args.output, orient='records')
|
||||
elif args.format == 'csv':
|
||||
df.to_csv(args.output, index=False)
|
||||
Reference in New Issue
Block a user