diff --git a/train_network.py b/train_network.py index 17e84e87..5aa8af48 100644 --- a/train_network.py +++ b/train_network.py @@ -135,6 +135,8 @@ def train(args): gc.collect() # prepare network + import sys + sys.path.append(os.path.dirname(__file__)) print("import network module:", args.network_module) network_module = importlib.import_module(args.network_module)