diff --git a/app/api/deepfake.py b/app/api/deepfake.py
index e283602502bd21e9252c3969c66a18fbc44d4f04..309a1c797ecba7925c8ef396d963fdc2b0cbd137 100644
--- a/app/api/deepfake.py
+++ b/app/api/deepfake.py
@@ -1,4 +1,5 @@
 import logging
+import mimetypes
 import os
 
 from flask import jsonify, request, send_from_directory
@@ -29,6 +30,7 @@ def upload_video():
     video.video_name = video_name
     video.status = 0
     video.progress = 0
+    video.probability = 0.
     db.session.add(video)
     db.session.commit()
     return jsonify(success=True, data=video_name, status_code=200)
@@ -57,7 +59,13 @@ def view_video(video_name):
     data_path = os.path.join(file_utils.VIDEO_FOLDER, video_name)
     if not os.path.exists(data_path):
         raise NotFound()
-    return send_from_directory(file_utils.VIDEO_FOLDER, video_name, mimetype='video/mp4', as_attachment=False)
+    if video_name.endswith('.mp4'):
+        mimetype = 'video/mp4'
+    elif video_name.endswith('.webm'):
+        mimetype = 'video/webm'
+    else:
+        mimetype = mimetypes.guess_type(video_name)[0] or 'video/mp4'
+    return send_from_directory(file_utils.VIDEO_FOLDER, video_name, mimetype=mimetype, as_attachment=False)
 
 
 @bp.route('/deepfake/video/<int:video_id>', methods=['DELETE'])
@@ -69,7 +77,7 @@ def delete_video(video_id):
         raise BadRequest('Illegal argument!')
 
     file_utils.remove_video(video_entity.video_name)
-    file_utils.remove_video(f'{video_entity.video_name[:-4]}_mask.mp4')
+    file_utils.remove_video(f'{video_entity.video_name[:-4]}_mask.webm')
     db.session.delete(video_entity)
     db.session.commit()
     return jsonify(success=True, data=None, status_code=200)
diff --git a/app/api/deepfake_detection.py b/app/api/deepfake_detection.py
index 65c3b86ad4bb52c0c96b0157f7be5bcda98bf166..4b2475b72d8406901a83407fc91a09771a79f2a0 100644
--- a/app/api/deepfake_detection.py
+++ b/app/api/deepfake_detection.py
@@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
 
 import cv2
 import numpy as np
+import pymysql
 from PIL import Image
 from flask import jsonify
 from sqlalchemy import create_engine, update
@@ -10,11 +11,10 @@ from sqlalchemy.orm import sessionmaker, scoped_session
 from werkzeug.exceptions import BadRequest
 
 import config
-from app import db
 from app.model.models import DeepfakeVideo
 from app.recognition import DeepfakeDetectionSIFDNet
 from app.recognition.face_detector import FacenetDetector
-from app.utils import file_utils
+from app.utils import file_utils, metrics_utils
 from app.utils.image_utils import crop_and_resize, ndarray_error_threshold
 from . import bp
 
@@ -29,81 +29,106 @@ pool = ThreadPoolExecutor(max_workers=1)
 # loop = asyncio.get_event_loop()
 FPS = 25
 
-detector = FacenetDetector(landmarks=False)
-deepfake_detection_sifdnet = DeepfakeDetectionSIFDNet.instance()
+FACE_DETECTOR = FacenetDetector(landmarks=False)
+DEEPFAKE_DETECTOR = DeepfakeDetectionSIFDNet.instance()
 
 
-def run_detection(video_entity, face_detector, deepfake_detection_model):
+def exception_callback(future):
+    exception = future.exception()
+    if exception:
+        LOG.error(exception)
+
+
+def run_detection(video_entity, face_detector, deepfake_detector):
     session = Session()
-    session.execute(
-        update(DeepfakeVideo).where(DeepfakeVideo.id == video_entity.id).values(status=1, progress=0)
-    )
-    session.commit()
-    # video_entity = session
-    # video_entity.status = 1
-    video_path = str(file_utils.VIDEO_FOLDER / video_entity.video_name)
-    frames = file_utils.get_video_frames(video_path)
-
-    mask_frames = []
     try:
+        # Mark the video as being analyzed: status=1, progress=0
+        session.execute(
+            update(DeepfakeVideo)
+            .where(DeepfakeVideo.id == video_entity.id)
+            .values(status=1, progress=0)
+        )
+        session.commit()
+        video_path = str(file_utils.VIDEO_FOLDER / video_entity.video_name)
+        frames = file_utils.get_video_frames(video_path)
+
+        mask_frames = []
+        fake_probs = []
         for index, frame in frames.items():
             # if index > 50:
             #     break
+
+            # Update the execution progress once per second of video
             if (index + 1) % FPS == 0:
                 progress = int((index + 1) / len(frames) * 100)
                 session.execute(
-                    update(DeepfakeVideo).where(DeepfakeVideo.id == video_entity.id).values(progress=progress)
+                    update(DeepfakeVideo)
+                    .where(DeepfakeVideo.id == video_entity.id)
+                    .values(progress=progress)
                 )
                 session.commit()
+
+            # Initialize an empty mask frame
             mask_frame = np.zeros_like(frame)
+
+            # Get bounding boxes through face detection on a half-size image
             resized_image = Image.fromarray(frame)
             resized_image = resized_image.resize(size=(resized_image.size[0] // 2, resized_image.size[1] // 2))
             bboxes = face_detector.detect_faces(resized_image)
-            # frame = np.asarray(frame)
+
+            # Accumulate the fake probability over all faces in this frame
+            frame_fake_prob_sum = 0.
             for bbox in bboxes:
+                # Crop the face region with one-third padding on each side
                 xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox]
                 w = xmax - xmin
                 h = ymax - ymin
                 p_h = h // 3
                 p_w = w // 3
                 crop = frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w]
-                # cv2.imwrite(f"../temp/{index}.png", crop)
-                y_pred, masks_pred = deepfake_detection_model.get_prediction(crop)
-                masks_pred = ndarray_error_threshold(masks_pred)
-                masks_pred = crop_and_resize(masks_pred, crop)
-                masks_pred = masks_pred * 255
-                masks_pred = np.expand_dims(masks_pred, axis=2)
-                masks_pred = np.tile(masks_pred, (1, 1, 3))
-                # masks_pred = np.ones_like(crop) * 255
-                mask_frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w] = masks_pred
+                fake_prob, mask_pred = deepfake_detector.get_predictions(crop)
+                frame_fake_prob_sum += fake_prob
+
+                # Fill the frame mask with the predicted manipulation-region mask
+                mask_pred = ndarray_error_threshold(mask_pred)
+                mask_pred = crop_and_resize(mask_pred, crop)
+                mask_pred = mask_pred * 255
+                mask_pred = np.expand_dims(mask_pred, axis=2)
+                mask_pred = np.tile(mask_pred, (1, 1, 3))
+                mask_frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w] = mask_pred
+            # Skip frames without a detected face to avoid dividing by zero
+            if bboxes:
+                fake_probs.append(frame_fake_prob_sum / len(bboxes))
             mask_frames.append(mask_frame)
+
+        # Update status=2, progress=100 and the video-level probability
+        fake_probs = np.array(fake_probs)
+        result = metrics_utils.get_video_fake_prob(fake_probs)
         session.execute(
-            update(DeepfakeVideo).where(DeepfakeVideo.id == video_entity.id).values(status=2, progress=100)
-        )
-        session.commit()
-    except Exception as e:
-        LOG.info(e)
-        session.execute(
-            update(DeepfakeVideo).where(DeepfakeVideo.id == video_entity.id).values(status=0, progress=0)
+            update(DeepfakeVideo)
+            .where(DeepfakeVideo.id == video_entity.id)
+            .values(status=2, progress=100, probability=result)
         )
         session.commit()
-        return
 
-    shape = mask_frames[0].shape
-    mask_video_path = str(file_utils.VIDEO_FOLDER / f'{video_entity.video_name[:-4]}_mask.mp4')
-    fourcc = cv2.VideoWriter_fourcc(*'avc1')
-    # fourcc = *'avc1'
-    try:
+        # Save the mask video
+        shape = mask_frames[0].shape
+        mask_video_path = str(file_utils.VIDEO_FOLDER / f'{video_entity.video_name[:-4]}_mask.webm')
+        fourcc = cv2.VideoWriter_fourcc(*"vp80")
         video_tracked = cv2.VideoWriter(mask_video_path, fourcc, FPS, (shape[1], shape[0]))
         for frame in mask_frames:
             video_tracked.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
         video_tracked.release()
+    except pymysql.err.OperationalError as e:
+        # The DB connection itself failed, so the status row cannot be updated; only log
+        LOG.error(e)
     except Exception as e:
         LOG.info(e)
         session.execute(
-            update(DeepfakeVideo).where(DeepfakeVideo.id == video_entity.id).values(status=0, progress=0)
+            update(DeepfakeVideo)
+            .where(DeepfakeVideo.id == video_entity.id)
+            .values(status=3, progress=0, probability=0.)
         )
         session.commit()
+        return
 
 
 @bp.route('/deepfake/detection/<int:video_id>', methods=['GET'])
@@ -115,10 +140,15 @@ def deepfake_detection(video_id):
     if status == 1:
         raise BadRequest('Video is being analyzed.')
     running_count = DeepfakeVideo.query.filter_by(status=1).count()
-    if running_count >= 3:
-        raise BadRequest('The maximum number of tasks that can be run is 3. Please try again later.')
+    if running_count >= 1:
+        raise BadRequest('The maximum number of tasks that can be run is 1. Please try again later.')
 
-    pool.submit(run_detection, video_entity, detector, deepfake_detection_sifdnet)
+    try:
+        future = pool.submit(run_detection, video_entity, FACE_DETECTOR, DEEPFAKE_DETECTOR)
+        future.add_done_callback(exception_callback)
+    except Exception as e:
+        LOG.error(e)
+        return jsonify(success=False, data="Execution failed, please try again later!", status_code=500)
     # task = run_detection(video_entity, db.session, DETECTOR, deepfake_detection_sifdnet)
     # loop.run_in_executor()
     return jsonify(success=True, data=None, status_code=200)
@@ -127,6 +157,6 @@ def deepfake_detection(video_id):
 @bp.route('/detection/progress/<int:video_id>', methods=['GET'])
 def detection_progress(video_id):
     detection_result = DeepfakeVideo.query.with_entities(
-        DeepfakeVideo.status, DeepfakeVideo.progress
+        DeepfakeVideo.status, DeepfakeVideo.progress, DeepfakeVideo.probability
     ).filter_by(id=video_id).first()
     return jsonify(success=True, data=detection_result._asdict(), status_code=200)
diff --git a/app/model/models.py b/app/model/models.py
index b2fa768982bb6344c91950c505a4ea1cfb7f7779..3debe93e92c47dd449d5ae2de0adc167f85501da 100644
--- a/app/model/models.py
+++ b/app/model/models.py
@@ -59,9 +59,10 @@ class DeepfakeVideo(db.Model, PaginatedAPIMixin):
     __tablename__ = 't_deepfake_video'
 
     id = db.Column(db.Integer, primary_key=True)
-    video_name = db.Column(db.String(512), nullable=False)
+    video_name = db.Column(db.String(256), nullable=False)
     status = db.Column(db.Integer, nullable=False)
     progress = db.Column(db.Integer, nullable=False)
+    probability = db.Column(db.Float, nullable=False)
     create_date = db.Column(db.DateTime, nullable=False, default=datetime.utcnow)
 
     def __repr__(self):
@@ -73,6 +74,7 @@ class DeepfakeVideo(db.Model, PaginatedAPIMixin):
             'video_name': self.video_name,
             'status': self.status,
             'progress': self.progress,
+            'probability': self.probability,
             'create_date': self.create_date
         }
         return data
diff --git a/app/recognition/anti_spoofing_two_stream_vit.py b/app/recognition/anti_spoofing_two_stream_vit.py
index 76b0ff2bb211ba0649fa45da9562d3dffe0891ea..5537ed884c2e97b02791ce69b090f092de39a3b4 100644
--- a/app/recognition/anti_spoofing_two_stream_vit.py
+++ b/app/recognition/anti_spoofing_two_stream_vit.py
@@ -38,7 +38,7 @@ class AntiSpoofingTwoStreamVit(FaceModel):
             self.model.cuda()
         self.model.eval()
 
-    def get_embeddings(self, x: numpy.ndarray):
+    def get_embeddings(self, x: numpy.ndarray) -> numpy.ndarray:
         pass
 
     def get_prediction(self, x: numpy.ndarray, **kwargs) -> numpy.ndarray:
@@ -56,3 +56,6 @@ class AntiSpoofingTwoStreamVit(FaceModel):
         logits = self.model(x, msr)
         y_preds = torch.argmax(logits, dim=0).detach().cpu().numpy()
         return y_preds
+
+    def get_predictions(self, x: numpy.ndarray, **kwargs) -> (float, numpy.ndarray):
+        pass
diff --git a/app/recognition/deepfake_detection_sifdnet.py b/app/recognition/deepfake_detection_sifdnet.py
index d2eff9b5c2b1c1f0b31154bce5c8a94ff35975d7..a038eab854c3fdcf61a8f81e1591ccd5db84e3c4 100644
--- a/app/recognition/deepfake_detection_sifdnet.py
+++ b/app/recognition/deepfake_detection_sifdnet.py
@@ -3,6 +3,7 @@ import re
 
 import numpy
 import torch
+import torch.nn.functional as F
 
 import config
 from .face_model import FaceModel
@@ -38,16 +39,21 @@ class DeepfakeDetectionSIFDNet(FaceModel):
             self.model.cuda()
         self.model.eval()
 
-    def get_embeddings(self, x: numpy.ndarray):
+    def get_embeddings(self, x: numpy.ndarray) -> numpy.ndarray:
         pass
 
-    def get_prediction(self, x: numpy.ndarray, **kwargs):
+    def get_prediction(self, x: numpy.ndarray, **kwargs) -> numpy.ndarray:
+        pass
+
+    def get_predictions(self, x: numpy.ndarray, **kwargs) -> (float, numpy.ndarray):
         transformed = self.transform(image=x)
         x = transformed["image"]
         x = x.unsqueeze(0)
         if next(self.model.parameters()).is_cuda:
             x = x.cuda(device='cuda:0')
         logits_pred, masks_pred = self.model(x)
-        y_pred = torch.argmax(logits_pred, dim=1).detach().cpu().numpy()
+        # Softmax over the two logits; class 1 is the fake probability
+        probs_pred = F.softmax(logits_pred, dim=1)
+        fake_prob = probs_pred[0, 1].item()
         masks_pred = masks_pred.squeeze().detach().cpu().numpy()
-        return y_pred, masks_pred
+        return fake_prob, masks_pred
diff --git a/app/recognition/face_model.py b/app/recognition/face_model.py
index cef61868848354b5ed69497827e0079ec64d8e2a..c9b7e7181b9876395dbc98b6e0990fef7a12e0d1 100644
--- a/app/recognition/face_model.py
+++ b/app/recognition/face_model.py
@@ -11,3 +11,6 @@ class FaceModel:
 
     def get_prediction(self, x: numpy.ndarray, **kwargs) -> numpy.ndarray:
         raise NotImplementedError
+
+    def get_predictions(self, x: numpy.ndarray, **kwargs) -> (float, numpy.ndarray):
+        raise NotImplementedError
diff --git a/app/recognition/facenet.py b/app/recognition/facenet.py
index 9fc4c6d822217b4bbe2a63ab272dadf09765013a..152922ca05862e92cd0cb24f44a98d3437d58d61 100644
--- a/app/recognition/facenet.py
+++ b/app/recognition/facenet.py
@@ -48,3 +48,6 @@ class FaceNetTorch(FaceModel):
 
     def get_prediction(self, x: numpy.ndarray, **kwargs) -> numpy.ndarray:
         pass
+
+    def get_predictions(self, x: numpy.ndarray, **kwargs) -> (float, numpy.ndarray):
+        pass
diff --git a/app/utils/metrics_utils.py b/app/utils/metrics_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..237854f5a5087d03cc8e633cd58ff90b364c5dbb
--- /dev/null
+++ b/app/utils/metrics_utils.py
@@ -0,0 +1,10 @@
+import numpy as np
+
+
+def get_video_fake_prob(fake_probs: np.ndarray) -> float:
+    real_num = len(fake_probs[fake_probs < 0.5])
+    fake_num = len(fake_probs[fake_probs >= 0.5])
+    if fake_num == 0 or real_num == 0:
+        return np.sum(fake_probs) / fake_probs.shape[0]
+    else:
+        return 0.5 * (1 + (np.sum(fake_probs) / fake_probs.shape[0]))
diff --git a/sql/scheme.sql b/sql/scheme.sql
index 583582a9514e262285a7e0c6892637beddcbff0a..d3289ad65c201ac3a68cfbf7a08c2a584dad1316 100644
--- a/sql/scheme.sql
+++ b/sql/scheme.sql
@@ -53,9 +53,10 @@ DROP TABLE IF EXISTS t_deepfake_video;
 CREATE TABLE t_deepfake_video
 (
     id          int NOT NULL AUTO_INCREMENT,
-    video_name  varchar(512) NOT NULL COMMENT 'video name',
-    status      TINYINT NOT NULL COMMENT 'status: 0 - not analyzed, 1 - analyzing, 2 - finished',
+    video_name  varchar(256) NOT NULL COMMENT 'video name',
+    status      TINYINT NOT NULL COMMENT 'status: 0 - not analyzed, 1 - analyzing, 2 - finished, 3 - failed',
     progress    TINYINT NOT NULL COMMENT 'detection progress [0, 100]',
+    probability FLOAT NOT NULL COMMENT 'fake probability',
     create_date datetime NOT NULL COMMENT 'upload time',
     PRIMARY KEY (`id`)
 ) ENGINE = InnoDB AUTO_INCREMENT = 1 CHARSET = utf8mb4;
diff --git a/tests/test_env.py b/tests/test_cuda_env.py
similarity index 95%
rename from tests/test_env.py
rename to tests/test_cuda_env.py
index 26a1d77d2acbff798361ce8495e8dea314acba33..c30e85b4f9385098a4a2335e497a2ccdc5121bed 100644
--- a/tests/test_env.py
+++ b/tests/test_cuda_env.py
@@ -1,6 +1,5 @@
 import unittest
 
-import tensorflow.compat.v1 as tf
 import torch
 
 
diff --git a/tests/test_deepfake_detection_sifdnet.py b/tests/test_deepfake_detection_sifdnet.py
index e2de8f5d6a9cdf7964ca8239edf3409e80cb741d..347ff100e4620cb9faa601ef38018ea294cc1497 100644
--- a/tests/test_deepfake_detection_sifdnet.py
+++ b/tests/test_deepfake_detection_sifdnet.py
@@ -14,14 +14,14 @@ from app.utils.image_utils import ndarray_error_threshold, crop_and_resize
 class DeepfakeDetectionSIFDNetTestCase(unittest.TestCase):
 
     def test_prediction(self):
-        os.makedirs('../temp', exist_ok=True)
-        video_path = str(file_utils.VIDEO_FOLDER / '000_003.mp4')
-        deepfake_detection_sifdnet = DeepfakeDetectionSIFDNet()
+        video_path = str(file_utils.VIDEO_FOLDER / '000_003-1.mp4')
+        deepfake_detection_sifdnet = DeepfakeDetectionSIFDNet.instance()
         deepfake_detection_sifdnet.create_model()
         frames = file_utils.get_video_frames(video_path)
         detector = FacenetDetector(landmarks=False)
 
         mask_frames = []
+        fake_probs = []
         for index, frame in frames.items():
             # h, w = frame.shape[:2]
             # if index > 50:
@@ -30,7 +30,7 @@ class DeepfakeDetectionSIFDNetTestCase(unittest.TestCase):
             resized_image = Image.fromarray(frame)
             resized_image = resized_image.resize(size=(resized_image.size[0] // 2, resized_image.size[1] // 2))
             bboxes = detector.detect_faces(resized_image)
-            # frame = np.asarray(frame)
+            frame_fake_prob_sum = 0.
             for bbox in bboxes:
                 xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox]
                 w = xmax - xmin
@@ -38,20 +38,20 @@ class DeepfakeDetectionSIFDNetTestCase(unittest.TestCase):
                 p_h = h // 3
                 p_w = w // 3
                 crop = frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w]
-                # cv2.imwrite(f"../temp/{index}.png", crop)
-                y_pred, masks_pred = deepfake_detection_sifdnet.get_prediction(crop)
+                fake_prob, masks_pred = deepfake_detection_sifdnet.get_predictions(crop)
+                frame_fake_prob_sum += fake_prob
                 masks_pred = ndarray_error_threshold(masks_pred)
                 masks_pred = crop_and_resize(masks_pred, crop)
                 masks_pred = masks_pred * 255
                 masks_pred = np.expand_dims(masks_pred, axis=2)
                 masks_pred = np.tile(masks_pred, (1, 1, 3))
-                # masks_pred = np.ones_like(crop) * 255
                 mask_frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w] = masks_pred
+            fake_probs.append(frame_fake_prob_sum / len(bboxes))
             mask_frames.append(mask_frame)
 
         shape = mask_frames[0].shape
-        video_tracked = cv2.VideoWriter('video_tracked.mp4', cv2.VideoWriter_fourcc(*'MP4V'), 25.0,
-                                        (shape[1], shape[0]))
+        video_path = str(file_utils.VIDEO_FOLDER / 'video_tracked.webm')
+        video_tracked = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*"vp80"), 25.0, (shape[1], shape[0]))
         for frame in mask_frames:
             video_tracked.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
         video_tracked.release()
diff --git a/tests/test_metrics_utils.py b/tests/test_metrics_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a96f255ee69c3a5fad4362905160424fc816bd6
--- /dev/null
+++ b/tests/test_metrics_utils.py
@@ -0,0 +1,44 @@
+import unittest
+
+import numpy as np
+
+from app.utils import metrics_utils
+
+
+class MetricsUtilsTestCase(unittest.TestCase):
+
+    def test_get_video_fake_probability(self):
+        probs = [0.9] * 5
+        probs = np.array(probs)
+        result = metrics_utils.get_video_fake_prob(probs)
+        self.assertAlmostEqual(0.9, result)
+
+        probs = [0.1] * 5
+        probs = np.array(probs)
+        result = metrics_utils.get_video_fake_prob(probs)
+        self.assertAlmostEqual(0.1, result)
+
+        probs = [0.1, 0.7, 0.8, 0.9]
+        probs = np.array(probs)
+        result = metrics_utils.get_video_fake_prob(probs)
+        self.assertAlmostEqual(0.8125, result)
+
+        probs = [0.1, 0.2, 0.8, 0.9]
+        probs = np.array(probs)
+        result = metrics_utils.get_video_fake_prob(probs)
+        self.assertAlmostEqual(0.75, result)
+
+        probs = [0.1, 0.1, 0.1, 0.6]
+        probs = np.array(probs)
+        result = metrics_utils.get_video_fake_prob(probs)
+        self.assertAlmostEqual(0.6125, result)
+
+        probs = [0.1, 0.1, 0.1, 0.7]
+        probs = np.array(probs)
+        result = metrics_utils.get_video_fake_prob(probs)
+        self.assertAlmostEqual(0.625, result)
+
+        probs = [0.1, 0.1, 0.1, 0.8]
+        probs = np.array(probs)
+        result = metrics_utils.get_video_fake_prob(probs)
+        self.assertAlmostEqual(0.6375, result)
diff --git a/tests/video_tracked.mp4 b/tests/video_tracked.mp4
deleted file mode 100644
index 0e722e40c8d2827bdd8a83ee3ae639548d59be01..0000000000000000000000000000000000000000
Binary files a/tests/video_tracked.mp4 and /dev/null differ
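
Note on the new video-level score (app/utils/metrics_utils.py): when every frame-level probability falls on the same side of the 0.5 decision threshold, the video score is simply their mean; any mixed video is instead mapped into the fake half of the range via 0.5 * (1 + mean), so even a briefly manipulated video scores at least 0.5. A minimal standalone sketch of how the helper behaves, assuming the package layout in this patch:

    import numpy as np

    from app.utils import metrics_utils

    # All frames agree, so the video score is just the mean probability.
    print(metrics_utils.get_video_fake_prob(np.array([0.9] * 5)))  # ~0.9, fake
    print(metrics_utils.get_video_fake_prob(np.array([0.1] * 5)))  # ~0.1, real

    # Mixed frames: mean = 0.625 is mapped to 0.5 * (1 + 0.625) = 0.8125,
    # matching the expectation in tests/test_metrics_utils.py.
    print(metrics_utils.get_video_fake_prob(np.array([0.1, 0.7, 0.8, 0.9])))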
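
Note on the codec switch: mask videos are now written as VP8/WebM (fourcc "vp80") instead of H.264 (fourcc "avc1"), presumably so browsers can play the mask overlay without a proprietary codec. cv2.VideoWriter fails silently when the OpenCV build lacks the requested encoder, so a hypothetical smoke test for a deployment machine (file name and frame size are arbitrary):

    import cv2
    import numpy as np

    # VideoWriter.isOpened() is False when the VP8 encoder is unavailable.
    writer = cv2.VideoWriter('codec_check.webm', cv2.VideoWriter_fourcc(*'vp80'), 25, (64, 64))
    print('VP8 encoder available:', writer.isOpened())
    if writer.isOpened():
        writer.write(np.zeros((64, 64, 3), dtype=np.uint8))  # one black frame
    writer.release()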