Created guide and script to export embeddings
This commit is contained in:
@@ -152,19 +152,23 @@ def create(args):
|
||||
df = pd.concat([df, df_aux], axis=1)
|
||||
if args.create_new_split:
|
||||
df_train, df_test = train_test_split(df, test_size=test_size, stratify=df['labels'], random_state=42)
|
||||
else:
|
||||
print(df['split'].unique())
|
||||
df_train = df[(df['split'] == 'train') | (df['split'] == 'val')]
|
||||
df_test = df[df['split'] == 'test']
|
||||
|
||||
|
||||
print(f'Num classes: {num_classes}')
|
||||
print(df_train['labels'].value_counts())
|
||||
assert set(df_train['labels'].unique()) == set(df_test['labels'].unique(
|
||||
)), 'The labels for train and test dataframe are different. We recommend to download the dataset again, or to use \
|
||||
the --create-new-split flag'
|
||||
for split, df_split in zip(['train', 'val'],
|
||||
[df_train, df_test]):
|
||||
fn_out = op.join(dataset_folder, f'fingerspelling_{split}.csv')
|
||||
(df_split.reset_index(drop=True)
|
||||
.applymap(convert_to_str)
|
||||
.to_csv(fn_out, index=False))
|
||||
print(f'Num classes: {num_classes}')
|
||||
print(df_train['labels'].value_counts())
|
||||
print(df_test['labels'].value_counts())
|
||||
assert set(df_train['labels'].unique()) == set(df_test['labels'].unique(
|
||||
)), 'The labels for train and test dataframe are different. We recommend to download the dataset again, or to use \
|
||||
the --create-new-split flag'
|
||||
for split, df_split in zip(['train', 'val'],
|
||||
[df_train, df_test]):
|
||||
fn_out = op.join(dataset_folder, f'{split}.csv')
|
||||
(df_split.reset_index(drop=True)
|
||||
.applymap(convert_to_str)
|
||||
.to_csv(fn_out, index=False))
|
||||
|
||||
else:
|
||||
fn_out = op.join(dataset_folder, 'train.csv')
|
||||
(df.reset_index(drop=True)
|
||||
.applymap(convert_to_str)
|
||||
.to_csv(fn_out, index=False))
|
||||
Reference in New Issue
Block a user