Created guide and script to export embeddings

This commit is contained in:
2023-04-14 14:40:05 +00:00
parent 49ced1983d
commit 1f24df1b8f
6 changed files with 255 additions and 150 deletions

View File

@@ -152,19 +152,23 @@ def create(args):
df = pd.concat([df, df_aux], axis=1)
if args.create_new_split:
df_train, df_test = train_test_split(df, test_size=test_size, stratify=df['labels'], random_state=42)
else:
print(df['split'].unique())
df_train = df[(df['split'] == 'train') | (df['split'] == 'val')]
df_test = df[df['split'] == 'test']
print(f'Num classes: {num_classes}')
print(df_train['labels'].value_counts())
assert set(df_train['labels'].unique()) == set(df_test['labels'].unique(
)), 'The labels for train and test dataframe are different. We recommend to download the dataset again, or to use \
the --create-new-split flag'
for split, df_split in zip(['train', 'val'],
[df_train, df_test]):
fn_out = op.join(dataset_folder, f'fingerspelling_{split}.csv')
(df_split.reset_index(drop=True)
.applymap(convert_to_str)
.to_csv(fn_out, index=False))
print(f'Num classes: {num_classes}')
print(df_train['labels'].value_counts())
print(df_test['labels'].value_counts())
assert set(df_train['labels'].unique()) == set(df_test['labels'].unique(
)), 'The labels for train and test dataframe are different. We recommend to download the dataset again, or to use \
the --create-new-split flag'
for split, df_split in zip(['train', 'val'],
[df_train, df_test]):
fn_out = op.join(dataset_folder, f'{split}.csv')
(df_split.reset_index(drop=True)
.applymap(convert_to_str)
.to_csv(fn_out, index=False))
else:
fn_out = op.join(dataset_folder, 'train.csv')
(df.reset_index(drop=True)
.applymap(convert_to_str)
.to_csv(fn_out, index=False))