From ec5d5adc004d43b7fb4e0c9f22daab7a2449de16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=99=E5=B0=91=E5=8D=8E?= <5457511+meng_xiao_hua@user.noreply.gitee.com> Date: Tue, 11 Jan 2022 15:35:59 +0800 Subject: [PATCH] feat: add vision_transformer_attention_fusion_qkv.py --- .gitignore | 1 + app/api/face.py | 10 + .../anti_spoofing_two_stream_vit.py | 52 ++ .../{face_recognition.py => face_model.py} | 6 +- app/recognition/facenet.py | 7 +- app/recognition/transform.py | 87 +++ ...vision_transformer_attention_fusion_qkv.py | 501 ++++++++++++++++++ app/utils/config_msr.json | 9 + app/utils/retinex.py | 126 +++++ config.py | 2 +- requirements.txt | 1 + 11 files changed, 798 insertions(+), 4 deletions(-) create mode 100644 app/recognition/anti_spoofing_two_stream_vit.py rename app/recognition/{face_recognition.py => face_model.py} (69%) create mode 100644 app/recognition/transform.py create mode 100644 app/recognition/vision_transformer_attention_fusion_qkv.py create mode 100644 app/utils/config_msr.json create mode 100644 app/utils/retinex.py diff --git a/.gitignore b/.gitignore index e31cee4..a3e3c8e 100644 --- a/.gitignore +++ b/.gitignore @@ -103,6 +103,7 @@ ENV/ # Custom .idea/ logs/ +weights/ 20170512-110547/ 20180402-114759/ 20180408-102900/ diff --git a/app/api/face.py b/app/api/face.py index 33b7681..bdf1bbd 100644 --- a/app/api/face.py +++ b/app/api/face.py @@ -1,6 +1,7 @@ import logging import pickle +import cv2 import numpy as np from flask import jsonify, request from werkzeug.exceptions import BadRequest @@ -10,6 +11,8 @@ from app.model.models import User, Photo, FaceModel from app.recognition import compare from app.utils import file_utils from . import bp +from app.utils import retinex +from app.recognition.vision_transformer_attention_fusion_qkv import vit_base_resnet50_224_in21k LOG = logging.getLogger(__name__) @@ -21,6 +24,13 @@ def face_recognition(): raise BadRequest('Must include photo.') photo = data['photo'] input_image = file_utils.resize_blob_to_160x160(photo) + # 这里之后加活体检测代码 + input_image_msrcr = cv2.cvtColor(input_image, cv2.COLOR_BGR2GRAY) + input_image_msrcr = cv2.cvtColor(input_image_msrcr, cv2.COLOR_GRAY2BGR) + input_image_msrcr = retinex.MSRCR(input_image_msrcr, sigma_list=[15, 80, 250], G=5.0, b=25.0, alpha=125.0, beta=46.0, low_clip=0.01, high_clip=0.99) + is_true = vit_base_resnet50_224_in21k(input_image, input_image_msrcr) + + embedding = face_net_torch.get_embeddings(input_image) photos = Photo.query.all() x, y = [], [] diff --git a/app/recognition/anti_spoofing_two_stream_vit.py b/app/recognition/anti_spoofing_two_stream_vit.py new file mode 100644 index 0000000..b9b4f1a --- /dev/null +++ b/app/recognition/anti_spoofing_two_stream_vit.py @@ -0,0 +1,52 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import re + +import numpy +import torch + +import config +from .face_model import FaceModel +from .transform import create_val_test_transform +from .vision_transformer_attention_fusion_qkv import vit_base_resnet50_224_in21k + +LOG = logging.getLogger(__name__) + + +class AntiSpoofingWtoStreamVit(FaceModel): + + def __init__(self, app=None): + super().__init__() + self.model = vit_base_resnet50_224_in21k(pretrained=False, in_chans=3, num_classes=2) + self.transform = create_val_test_transform(model_cfg=self.model.default_cfg) + if app is not None: + self.init_app(app) + + def init_app(self, app): + self.create_model() + + def create_model(self): + # set TORCH_HOME in your os 
environment + + model_path = config.get_config('ANTI_SPOOFING_TWO_STREAM_VIT_PATH') + checkpoint = torch.load(model_path, map_location="cpu") + state_dict = checkpoint.get("state_dict", checkpoint) + self.model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=False) + if torch.cuda.is_available(): + self.model.cuda() + self.model.eval() + + def get_embeddings(self, input_image: numpy.ndarray): + pass + + def get_y_prediction(self, input_image) -> numpy.ndarray: + x = self.transform(input_image) + x = x.unsqueeze(0) + if torch.cuda.is_available(): + x = x.cuda() + logits = self.model(x) + y_preds = torch.argmax(logits, dim=1) + return y_preds diff --git a/app/recognition/face_recognition.py b/app/recognition/face_model.py similarity index 69% rename from app/recognition/face_recognition.py rename to app/recognition/face_model.py index 3d804b4..fbfc6f6 100644 --- a/app/recognition/face_recognition.py +++ b/app/recognition/face_model.py @@ -1,7 +1,8 @@ import numpy -class FaceRecognition: +class FaceModel: + __shared_state = {} def __init__(self): @@ -12,3 +13,6 @@ class FaceRecognition: def get_embeddings(self, input_image) -> numpy.ndarray: raise NotImplementedError + + def get_y_prediction(self, input_image) -> numpy.ndarray: + raise NotImplementedError diff --git a/app/recognition/facenet.py b/app/recognition/facenet.py index 349a1fc..55c4c40 100644 --- a/app/recognition/facenet.py +++ b/app/recognition/facenet.py @@ -9,12 +9,12 @@ import torch from facenet_pytorch import InceptionResnetV1 from torchvision.transforms import transforms -from .face_recognition import FaceRecognition +from .face_model import FaceModel LOG = logging.getLogger(__name__) -class FaceNetTorch(FaceRecognition): +class FaceNetTorch(FaceModel): def __init__(self, app=None): super().__init__() @@ -44,3 +44,6 @@ class FaceNetTorch(FaceRecognition): embeddings = self.model(x) embeddings = embeddings.detach().cpu().numpy() return embeddings[0] + + def get_y_prediction(self, input_image) -> numpy.ndarray: + pass diff --git a/app/recognition/transform.py b/app/recognition/transform.py new file mode 100644 index 0000000..5318ee6 --- /dev/null +++ b/app/recognition/transform.py @@ -0,0 +1,87 @@ +import cv2 +from albumentations import DualTransform +from albumentations.pytorch import ToTensorV2 +import albumentations as A + + +def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC): + h, w = img.shape[:2] + if max(w, h) == size: + return img + if w > h: + scale = size / w + h = h * scale + w = size + else: + scale = size / h + w = w * scale + h = size + interpolation = interpolation_up if scale > 1 else interpolation_down + resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation) + return resized + + +class IsotropicResize(DualTransform): + + def __init__(self, max_side, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, + always_apply=False, p=1): + super(IsotropicResize, self).__init__(always_apply, p) + self.max_side = max_side + self.interpolation_down = interpolation_down + self.interpolation_up = interpolation_up + + def apply(self, img, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, **params): + return isotropically_resize_image(img, size=self.max_side, interpolation_down=interpolation_down, + interpolation_up=interpolation_up) + + def apply_to_mask(self, img, **params): + return self.apply(img, interpolation_down=cv2.INTER_NEAREST, interpolation_up=cv2.INTER_NEAREST, 
**params) + + def get_transform_init_args_names(self): + return "max_side", "interpolation_down", "interpolation_up" + + def apply_to_bbox(self, bbox, **params): + pass + + def apply_to_keypoint(self, keypoint, **params): + pass + + def get_params_dependent_on_targets(self, params): + pass + + +def create_train_transform(model_cfg): + size = model_cfg['input_size'][1] + mean = model_cfg['mean'] + std = model_cfg['std'] + return A.Compose([ + A.HorizontalFlip(), + # A.Resize(height=224, width=224), + # A.Resize(height=128, width=128), + # A.RandomCrop(height=114, width=114), + A.OneOf([ + IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC), + IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR), + IsotropicResize(max_side=size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR), + ], p=1), + A.PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT, value=0), + # A.RandomCrop(height=input_size[1], width=input_size[2]), + # A.OneOf([A.RandomBrightnessContrast(), A.FancyPCA(), A.HueSaturationValue()], p=0.5), + # A.ToGray(p=0.2), + # A.ShiftScaleRotate(rotate_limit=(-20, 20), border_mode=cv2.BORDER_CONSTANT, p=0.5), + A.Normalize(mean=mean, std=std), + ToTensorV2(), + ], additional_targets={'image2': 'image', 'image3': 'image'}) + + +def create_val_test_transform(model_cfg): + size = model_cfg['input_size'][1] + mean = model_cfg['mean'] + std = model_cfg['std'] + return A.Compose([ + IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC), + A.PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT, value=0), + # A.Resize(height=224, width=224), + A.Normalize(mean=mean, std=std), + ToTensorV2(), + ], additional_targets={'image2': 'image', 'image3': 'image'}) diff --git a/app/recognition/vision_transformer_attention_fusion_qkv.py b/app/recognition/vision_transformer_attention_fusion_qkv.py new file mode 100644 index 0000000..bacbc9d --- /dev/null +++ b/app/recognition/vision_transformer_attention_fusion_qkv.py @@ -0,0 +1,501 @@ +""" Vision Transformer (ViT) in PyTorch + +A PyTorch implement of Vision Transformers as described in +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 + +The official jax code is released and available at https://github.com/google-research/vision_transformer + +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +DeiT model defs and weights from https://github.com/facebookresearch/deit, +paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +Hacked together by / Copyright 2020 Ross Wightman +""" +import math +import logging +from functools import partial +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.helpers import load_pretrained +from timm.models.layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_ +from timm.models.resnet import resnet26d, resnet50d +from timm.models.resnetv2 import ResNetV2 +from timm.models.registry import register_model + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (my experiments) + 'vit_small_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', + ), + + # patch models (weights ported from official Google JAX impl) + 'vit_base_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + ), + 'vit_base_patch32_224': _cfg( + url='', # no official model weights for this combo, only for in21k + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + + # hybrid models (weights ported from official Google JAX impl) + 'vit_base_resnet50_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, + first_conv='patch_embed.backbone.stem.conv'), + +} + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # 
make torchscript happy (cannot use tensor as tuple) + # print("q.shape: ", q.shape, "q: ", q) + # print("k.shape: ", k.shape, "k: ", k) + # print("v.shape: ", v.shape, "v: ", v) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature + # map for all networks, the feature metadata has reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of each stage that isn't captured. 
+ training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) + if isinstance(o, (list, tuple)): + o = o[-1] # last feature if backbone outputs list/tuple of features + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + if hasattr(self.backbone, 'feature_info'): + feature_dim = self.backbone.feature_info.channels()[-1] + else: + feature_dim = self.backbone.num_features + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Conv2d(feature_dim, embed_dim, 1) + + def forward(self, x): + x = self.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module + norm_layer: (nn.Module): normalization layer + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_heads = num_heads + head_dim = embed_dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # # Representation layer + # if representation_size: + # self.num_features = representation_size + # self.pre_logits = 
nn.Sequential(OrderedDict([ + # ('fc', nn.Linear(embed_dim, representation_size)), + # ('act', nn.Tanh()) + # ])) + # else: + # self.pre_logits = nn.Identity() + + # Classifier head + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + # self.q_kernel = nn.Parameter(torch.ones(1, 768)) + # self.v_kernel = nn.Parameter(torch.ones(1)) + self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + # x = self.norm(x)[:, 0] + x = self.norm(x)[:, 0] + # x = self.pre_logits(x) + return x + + def forward(self, x_rgb, x_msr): + x_rgb = self.forward_features(x_rgb) + x_msr = self.forward_features(x_msr) + x_concate = torch.cat((x_rgb.unsqueeze(1), x_msr.unsqueeze(1)), 1) + B, N, C = x_concate.shape + qkv = self.qkv(x_concate).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = x[:, 0, :] + x[:, 1, :] + x = x.squeeze() + x = self.head(x) + return x + + +class DistilledVisionTransformer(VisionTransformer): + """ Vision Transformer with distillation token. 
+ + Paper: `Training data-efficient image transformers & distillation through attention` - + https://arxiv.org/abs/2012.12877 + + This impl of distilled ViT is taken from https://github.com/facebookresearch/deit + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + num_patches = self.patch_embed.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() + + trunc_normal_(self.dist_token, std=.02) + trunc_normal_(self.pos_embed, std=.02) + self.head_dist.apply(self._init_weights) + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + + x = x + self.pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x[:, 0], x[:, 1] + + def forward(self, x): + x, x_dist = self.forward_features(x) + x = self.head(x) + x_dist = self.head_dist(x_dist) + if self.training: + return x, x_dist + else: + # during inference, return the average of both classifier predictions + return (x + x_dist) / 2 + + +def resize_pos_embed(posemb, posemb_new): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) + ntok_new = posemb_new.shape[1] + if True: + posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] + ntok_new -= 1 + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + gs_new = int(math.sqrt(ntok_new)) + _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear') + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + return posemb + + +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k and len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed(v, model.pos_embed) + out_dict[k] = v + return out_dict + + +def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs): + default_cfg = default_cfgs[variant] + default_num_classes = default_cfg['num_classes'] + default_img_size = default_cfg['input_size'][-1] + + num_classes = kwargs.pop('num_classes', default_num_classes) + img_size = kwargs.pop('img_size', default_img_size) + repr_size = kwargs.pop('representation_size', None) + if repr_size is not None and num_classes != 
default_num_classes: + # Remove representation layer if fine-tuning. This may not always be the desired action, + # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? + _logger.warning("Removing representation layer for fine-tuning.") + repr_size = None + + model_cls = DistilledVisionTransformer if distilled else VisionTransformer + model = model_cls(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs) + model.default_cfg = default_cfg + + if pretrained: + load_pretrained( + model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3), + filter_fn=partial(checkpoint_filter_fn, model=model)) + return model + + +@register_model +def vit_small_patch16_224(pretrained=False, **kwargs): + """ My custom 'small' ViT model. Depth=8, heads=8= mlp_ratio=3.""" + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., + qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs) + if pretrained: + # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model + model_kwargs.setdefault('qk_scale', 768 ** -0.5) + model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. + """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): + """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
+ """ + # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head + backbone = ResNetV2( + layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), + preact=False, stem_type='same', conv_layer=StdConv2dSame) + model_kwargs = dict( + embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, + representation_size=768, **kwargs) + model = _create_vision_transformer('vit_base_resnet50_224_in21k', pretrained=pretrained, **model_kwargs) + return model diff --git a/app/utils/config_msr.json b/app/utils/config_msr.json new file mode 100644 index 0000000..03757cc --- /dev/null +++ b/app/utils/config_msr.json @@ -0,0 +1,9 @@ +{ + "sigma_list": [15, 80, 250], + "G" : 5.0, + "b" : 25.0, + "alpha" : 125.0, + "beta" : 46.0, + "low_clip" : 0.01, + "high_clip" : 0.99 +} diff --git a/app/utils/retinex.py b/app/utils/retinex.py new file mode 100644 index 0000000..2ea29c2 --- /dev/null +++ b/app/utils/retinex.py @@ -0,0 +1,126 @@ +import numpy as np +import cv2 + + +def singleScaleRetinex(img, sigma): + retinex = np.log10(img) - np.log10(cv2.GaussianBlur(img, (0, 0), sigma)) + + return retinex + + +def multiScaleRetinex(img, sigma_list): + retinex = np.zeros_like(img) + for sigma in sigma_list: + retinex += singleScaleRetinex(img, sigma) + + retinex = retinex / len(sigma_list) + + return retinex + + +def colorRestoration(img, alpha, beta): + img_sum = np.sum(img, axis=2, keepdims=True) + + color_restoration = beta * (np.log10(alpha * img) - np.log10(img_sum)) + + return color_restoration + + +def simplestColorBalance(img, low_clip, high_clip): + total = img.shape[0] * img.shape[1] + for i in range(img.shape[2]): + unique, counts = np.unique(img[:, :, i], return_counts=True) + current = 0 + for u, c in zip(unique, counts): + if float(current) / total < low_clip: + low_val = u + if float(current) / total < high_clip: + high_val = u + current += c + + img[:, :, i] = np.maximum(np.minimum(img[:, :, i], high_val), low_val) + + return img + + +def MSRCR(img, sigma_list=None, G=5.0, b=25.0, alpha=125.0, beta=46.0, low_clip=0.01, high_clip=0.99): + if sigma_list is None: + sigma_list = [15, 80, 250] + img = np.float64(img) + 1.0 + + img_retinex = multiScaleRetinex(img, sigma_list) + img_color = colorRestoration(img, alpha, beta) + img_msrcr = G * (img_retinex * img_color + b) + + for i in range(img_msrcr.shape[2]): + img_msrcr[:, :, i] = (img_msrcr[:, :, i] - np.min(img_msrcr[:, :, i])) / \ + (np.max(img_msrcr[:, :, i]) - np.min(img_msrcr[:, :, i])) * \ + 255 + + img_msrcr = np.uint8(np.minimum(np.maximum(img_msrcr, 0), 255)) + img_msrcr = simplestColorBalance(img_msrcr, low_clip, high_clip) + + return img_msrcr + + +def automatedMSRCR(img, sigma_list): + img = np.float64(img) + 1.0 + + img_retinex = multiScaleRetinex(img, sigma_list) + + for i in range(img_retinex.shape[2]): + unique, count = np.unique(np.int32(img_retinex[:, :, i] * 100), return_counts=True) + for u, c in zip(unique, count): + if u == 0: + zero_count = c + break + + low_val = unique[0] / 100.0 + high_val = unique[-1] / 100.0 + for u, c in zip(unique, count): + if u < 0 and c < zero_count * 0.1: + low_val = u / 100.0 + if u > 0 and c < zero_count * 0.1: + high_val = u / 100.0 + break + + img_retinex[:, :, i] = np.maximum(np.minimum(img_retinex[:, :, i], high_val), low_val) + + img_retinex[:, :, i] = (img_retinex[:, :, i] - np.min(img_retinex[:, :, i])) / \ + (np.max(img_retinex[:, :, i]) - np.min(img_retinex[:, :, i])) \ + * 255 + + img_retinex = np.uint8(img_retinex) + + 
return img_retinex + + +def MSRCP(img, sigma_list, low_clip, high_clip): + img = np.float64(img) + 1.0 + + intensity = np.sum(img, axis=2) / img.shape[2] + + retinex = multiScaleRetinex(intensity, sigma_list) + + intensity = np.expand_dims(intensity, 2) + retinex = np.expand_dims(retinex, 2) + + intensity1 = simplestColorBalance(retinex, low_clip, high_clip) + + intensity1 = (intensity1 - np.min(intensity1)) / \ + (np.max(intensity1) - np.min(intensity1)) * \ + 255.0 + 1.0 + + img_msrcp = np.zeros_like(img) + + for y in range(img_msrcp.shape[0]): + for x in range(img_msrcp.shape[1]): + B = np.max(img[y, x]) + A = np.minimum(256.0 / B, intensity1[y, x, 0] / intensity[y, x, 0]) + img_msrcp[y, x, 0] = A * img[y, x, 0] + img_msrcp[y, x, 1] = A * img[y, x, 1] + img_msrcp[y, x, 2] = A * img[y, x, 2] + + img_msrcp = np.uint8(img_msrcp - 1.0) + + return img_msrcp diff --git a/config.py b/config.py index 853601d..0571b61 100644 --- a/config.py +++ b/config.py @@ -2,7 +2,7 @@ import os from dotenv import load_dotenv -CONFIG_OPTIONS = ['FACE_NET_MODEL_PATH', 'FACE_NET_TORCH_HOME', 'FACE_DATA_PATH'] +CONFIG_OPTIONS = ['FACE_NET_MODEL_PATH', 'FACE_NET_TORCH_HOME', 'FACE_DATA_PATH', 'ANTI_SPOOFING_TWO_STREAM_VIT_PATH'] CONFIGS = {} BASE_DIR = os.path.abspath(os.path.dirname(__file__)) diff --git a/requirements.txt b/requirements.txt index f493847..eeb8f7f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,3 +15,4 @@ python-dotenv scikit-learn tensorflow==1.15 werkzeug +albumentations -- Gitee
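
Usage sketch (appended for review, not part of the committed diff): the face_recognition() hunk still leaves the liveness check as a placeholder (the Chinese comment reads "add the liveness-detection code here later"), so the snippet below shows one way the new pieces, retinex.MSRCR, create_val_test_transform, and the two-stream vit_base_resnet50_224_in21k model, appear intended to fit together at inference time, mirroring create_model() and get_y_prediction() in app/recognition/anti_spoofing_two_stream_vit.py. The checkpoint path, the example image path, and which class index means "live" are assumptions, not something this patch pins down.

import re

import cv2
import torch

from app.recognition.transform import create_val_test_transform
from app.recognition.vision_transformer_attention_fusion_qkv import vit_base_resnet50_224_in21k
from app.utils import retinex

# Build the two-stream hybrid ViT the same way the new anti-spoofing class does and
# load a fine-tuned checkpoint (path is hypothetical; the app reads it from the
# ANTI_SPOOFING_TWO_STREAM_VIT_PATH config option).
model = vit_base_resnet50_224_in21k(pretrained=False, in_chans=3, num_classes=2)
checkpoint = torch.load("weights/anti_spoofing_two_stream_vit.pth", map_location="cpu")
state_dict = checkpoint.get("state_dict", checkpoint)
model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=False)
model.eval()

transform = create_val_test_transform(model_cfg=model.default_cfg)

# RGB stream: the face crop as-is. MSR stream: grayscale -> 3-channel -> MSRCR,
# mirroring the preprocessing added to app/api/face.py.
bgr = cv2.imread("face_crop.jpg")  # hypothetical input; the API uses the resized upload
gray3 = cv2.cvtColor(cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY), cv2.COLOR_GRAY2BGR)
msrcr = retinex.MSRCR(gray3, sigma_list=[15, 80, 250], G=5.0, b=25.0,
                      alpha=125.0, beta=46.0, low_clip=0.01, high_clip=0.99)

# Albumentations composes are called with keyword arguments and return a dict.
x_rgb = transform(image=bgr)["image"].unsqueeze(0)
x_msr = transform(image=msrcr)["image"].unsqueeze(0)

with torch.no_grad():
    logits = model(x_rgb, x_msr)  # two-stream forward with QKV attention fusion over the two cls tokens
# For a batch of one, the squeeze() in forward() leaves a 1-D logits vector, so reduce over the last dim.
pred = torch.argmax(logits, dim=-1).item()
# Whether index 0 or 1 corresponds to a live face depends on how the checkpoint was trained.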