Added ability to finetune models
This commit is contained in:
BIN
checkpoints/checkpoint_embed_3006.pth
Normal file
BIN
checkpoints/checkpoint_embed_3006.pth
Normal file
Binary file not shown.
@@ -30,12 +30,15 @@ def seed_worker(worker_id):
|
|||||||
|
|
||||||
generator = torch.Generator()
|
generator = torch.Generator()
|
||||||
generator.manual_seed(seed)
|
generator.manual_seed(seed)
|
||||||
|
import os
|
||||||
|
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description='Export embeddings')
|
parser = argparse.ArgumentParser(description='Export embeddings')
|
||||||
parser.add_argument('--checkpoint', type=str, default=None, help='Path to checkpoint')
|
parser.add_argument('--checkpoint', type=str, default=None, help='Path to checkpoint')
|
||||||
parser.add_argument('--output', type=str, default=None, help='Path to output')
|
parser.add_argument('--output', type=str, default=None, help='Path to output')
|
||||||
parser.add_argument('--dataset', type=str, default=None, help='Path to data')
|
parser.add_argument('--dataset', type=str, default=None, help='Path to data')
|
||||||
|
parser.add_argument('--format', type=str, default='csv', help='Format of the output file (csv, json)')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@@ -85,7 +88,9 @@ with torch.no_grad():
|
|||||||
df = pd.read_csv(args.dataset)
|
df = pd.read_csv(args.dataset)
|
||||||
df["embeddings"] = embeddings
|
df["embeddings"] = embeddings
|
||||||
df = df[['embeddings', 'label_name', 'labels']]
|
df = df[['embeddings', 'label_name', 'labels']]
|
||||||
df['embeddings2'] = df['embeddings'].apply(lambda x: x.tolist())
|
df['embeddings'] = df['embeddings'].apply(lambda x: x.tolist()[0])
|
||||||
|
|
||||||
|
if args.format == 'json':
|
||||||
df.to_csv(args.output, index=False)
|
df.to_json(args.output, orient='records')
|
||||||
|
elif args.format == 'csv':
|
||||||
|
df.to_csv(args.output, index=False)
|
||||||
@@ -12,10 +12,10 @@ output=32
|
|||||||
# load PyTorch model from .pth file
|
# load PyTorch model from .pth file
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
# if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
# device = torch.device("cuda")
|
||||||
|
|
||||||
CHECKPOINT_PATH = "out-checkpoints/augment_rotate_75_x8/checkpoint_embed_1105.pth"
|
CHECKPOINT_PATH = "checkpoints/checkpoint_embed_1105.pth"
|
||||||
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
|
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
|
||||||
|
|
||||||
model = SPOTER_EMBEDDINGS(
|
model = SPOTER_EMBEDDINGS(
|
||||||
@@ -30,7 +30,10 @@ model.eval()
|
|||||||
model_export = "onnx"
|
model_export = "onnx"
|
||||||
if model_export == "coreml":
|
if model_export == "coreml":
|
||||||
dummy_input = torch.randn(1, 10, 54, 2)
|
dummy_input = torch.randn(1, 10, 54, 2)
|
||||||
|
# set device for dummy input
|
||||||
|
dummy_input = dummy_input.to(device)
|
||||||
traced_model = torch.jit.trace(model, dummy_input)
|
traced_model = torch.jit.trace(model, dummy_input)
|
||||||
|
|
||||||
out = traced_model(dummy_input)
|
out = traced_model(dummy_input)
|
||||||
import coremltools as ct
|
import coremltools as ct
|
||||||
|
|
||||||
@@ -41,10 +44,12 @@ if model_export == "coreml":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Save Core ML model
|
# Save Core ML model
|
||||||
coreml_model.save("models/" + model_name + ".mlmodel")
|
coreml_model.save("out-models/" + model_name + ".mlmodel")
|
||||||
else:
|
else:
|
||||||
# create dummy input tensor
|
# create dummy input tensor
|
||||||
dummy_input = torch.randn(1, 10, 54, 2)
|
dummy_input = torch.randn(1, 10, 54, 2)
|
||||||
|
# set device for dummy input
|
||||||
|
dummy_input = dummy_input.to(device)
|
||||||
|
|
||||||
# export model to ONNX format
|
# export model to ONNX format
|
||||||
output_file = 'models/' + model_name + '.onnx'
|
output_file = 'models/' + model_name + '.onnx'
|
||||||
@@ -52,7 +57,7 @@ else:
|
|||||||
|
|
||||||
torch.onnx.export(model, # model being run
|
torch.onnx.export(model, # model being run
|
||||||
dummy_input, # model input (or a tuple for multiple inputs)
|
dummy_input, # model input (or a tuple for multiple inputs)
|
||||||
'output-models/' + model_name + '.onnx', # where to save the model (can be a file or file-like object)
|
'out-models/' + model_name + '.onnx', # where to save the model (can be a file or file-like object)
|
||||||
export_params=True, # store the trained parameter weights inside the model file
|
export_params=True, # store the trained parameter weights inside the model file
|
||||||
opset_version=9, # the ONNX version to export the model to
|
opset_version=9, # the ONNX version to export the model to
|
||||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||||
|
|||||||
2
train.py
2
train.py
@@ -77,7 +77,7 @@ def train(args, tracker: Tracker):
|
|||||||
if not args.classification_model:
|
if not args.classification_model:
|
||||||
# if finetune, load the weights from the classification model
|
# if finetune, load the weights from the classification model
|
||||||
if args.finetune:
|
if args.finetune:
|
||||||
checkpoint = torch.load(args.checkpoint, map_location=device)
|
checkpoint = torch.load(args.checkpoint_path, map_location=device)
|
||||||
|
|
||||||
slrt_model = SPOTER_EMBEDDINGS(
|
slrt_model = SPOTER_EMBEDDINGS(
|
||||||
features=checkpoint["config_args"].vector_length,
|
features=checkpoint["config_args"].vector_length,
|
||||||
|
|||||||
24
train.sh
24
train.sh
@@ -1,21 +1,23 @@
|
|||||||
#!/bin/sh
|
#!/bin/sh
|
||||||
python3 -m train \
|
python3 -m train \
|
||||||
--save_checkpoints_every 10 \
|
--save_checkpoints_every 1 \
|
||||||
--experiment_name "wlasl" \
|
--experiment_name "Finetune Basic Signs" \
|
||||||
--epochs 600 \
|
--epochs 100 \
|
||||||
--optimizer "SGD" \
|
--optimizer "ADAM" \
|
||||||
--lr 0.001 \
|
--lr 0.00001 \
|
||||||
--batch_size 16 \
|
--batch_size 16 \
|
||||||
--dataset_name "WLASL" \
|
--dataset_name "BasicSigns" \
|
||||||
--training_set_path "WLASL100_train.csv" \
|
--training_set_path "train.csv" \
|
||||||
--validation_set_path "WLASL100_val.csv" \
|
--validation_set_path "val.csv" \
|
||||||
--vector_length 32 \
|
--vector_length 32 \
|
||||||
--epoch_iters -1 \
|
--epoch_iters -1 \
|
||||||
--scheduler_factor 0.2 \
|
--scheduler_factor 0.05 \
|
||||||
--hard_triplet_mining "in_batch" \
|
--hard_triplet_mining "None" \
|
||||||
--filter_easy_triplets \
|
--filter_easy_triplets \
|
||||||
--triplet_loss_margin 2 \
|
--triplet_loss_margin 2 \
|
||||||
--dropout 0.2 \
|
--dropout 0.2 \
|
||||||
--tracker=clearml \
|
--tracker=clearml \
|
||||||
--dataset_loader=clearml \
|
--dataset_loader=clearml \
|
||||||
--dataset_project="SpoterEmbedding"
|
--dataset_project="SpoterEmbedding" \
|
||||||
|
--finetune \
|
||||||
|
--checkpoint_path "checkpoints/checkpoint_embed_3006.pth"
|
||||||
@@ -81,4 +81,7 @@ def get_default_args():
|
|||||||
help="Enables batching grouping scheduler if > 0. Defines threshold for when to decay the \
|
help="Enables batching grouping scheduler if > 0. Defines threshold for when to decay the \
|
||||||
distance threshold of the batch sorter")
|
distance threshold of the batch sorter")
|
||||||
|
|
||||||
|
parser.add_argument("--finetune", action='store_true', default=False, help="Fintune the model")
|
||||||
|
parser.add_argument("--checkpoint_path", type=str, default="")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|||||||
@@ -238,6 +238,7 @@ def distance_matrix(keypoints, embeddings, p=2, threshold=1000000):
|
|||||||
f"{kk}-dimensional vectors")
|
f"{kk}-dimensional vectors")
|
||||||
|
|
||||||
if m*n*k <= threshold:
|
if m*n*k <= threshold:
|
||||||
|
print("Using minkowski_distance")
|
||||||
return minkowski_distance(x[:,np.newaxis,:],y[np.newaxis,:,:],p)
|
return minkowski_distance(x[:,np.newaxis,:],y[np.newaxis,:,:],p)
|
||||||
else:
|
else:
|
||||||
result = np.empty((m,n),dtype=float) # FIXME: figure out the best dtype
|
result = np.empty((m,n),dtype=float) # FIXME: figure out the best dtype
|
||||||
|
|||||||
Reference in New Issue
Block a user