diff --git a/app/recognition/facenet.py b/app/recognition/facenet.py index e6a698d2f5a44ead5c11fe3ef6d5741bac627f61..349a1fc61e9bef6799fe0aa87890bd3cbecda067 100644 --- a/app/recognition/facenet.py +++ b/app/recognition/facenet.py @@ -32,13 +32,15 @@ class FaceNetTorch(FaceRecognition): def create_model(self): # set TORCH_HOME in your os environment self.model = InceptionResnetV1(pretrained='vggface2') - assert torch.cuda.is_available() - self.model.cuda() + if torch.cuda.is_available(): + self.model.cuda() self.model.eval() def get_embeddings(self, input_image: numpy.ndarray): x = self.transform(input_image) - x = x.unsqueeze(0).cuda() + x = x.unsqueeze(0) + if torch.cuda.is_available(): + x = x.cuda() embeddings = self.model(x) embeddings = embeddings.detach().cpu().numpy() return embeddings[0]