diff --git a/backend/exceptions.py b/backend/exceptions.py index 96da68d3e3a09676cd6c6f543383cd977bfe0ae3..253d3c33d99f09f8bf20f8734a64e58af88cffe3 100644 --- a/backend/exceptions.py +++ b/backend/exceptions.py @@ -1,7 +1,7 @@ from django.core.exceptions import ObjectDoesNotExist from django.db import DataError from rest_framework import status -from rest_framework.exceptions import ValidationError, APIException, PermissionDenied +from rest_framework.exceptions import ValidationError, PermissionDenied from rest_framework.response import Response from rest_framework.views import exception_handler @@ -16,28 +16,29 @@ def custom_exception_handler(exc, handler): 'error_msg': '数据库异常', }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - if isinstance(exc, Exception): - response = Response(data={ - 'code': 500, - 'msg': '有未处理的异常,请联系记录相关信息并联系后端', - 'error_msg': exc.args, - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - if isinstance(exc, ValidationError): + elif isinstance(exc, ValidationError): response = Response(data={ 'code': 400, 'error_msg': exc.detail }, status=status.HTTP_400_BAD_REQUEST) - if isinstance(exc, ObjectDoesNotExist): + elif isinstance(exc, ObjectDoesNotExist): response = Response(data={ 'code': 400, 'error_msg': exc.args, }, status=status.HTTP_400_BAD_REQUEST) - if isinstance(exc, PermissionDenied): + elif isinstance(exc, PermissionDenied): response = Response(data={ 'code': 403, 'error_msg': exc.detail }, status=status.HTTP_403_FORBIDDEN) + + elif isinstance(exc, Exception): + response = Response(data={ + 'code': 500, + 'msg': '有未处理的异常,请联系记录相关信息并联系后端', + 'error_msg': exc.args, + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + return response diff --git a/backend/helper.py b/backend/helper.py index ba184fa1148babc6f8a73b98a7c6b366f175c34c..0cc16becc4a20b5bfc2a4fbbfa2e1161e7e389be 100644 --- a/backend/helper.py +++ b/backend/helper.py @@ -43,9 +43,26 @@ class MyModelViewSet(viewsets.ModelViewSet): return Response(data=data, status=status.HTTP_200_OK) -# TODO: Custom response with api version. -class VersionResponse: - pass +class APIResponse(Response): + def __init__(self, success=True, code=200, + msg='success', data=None, status=None, + headers=None, content_type=None, **kwargs): + dict = { + 'success': success, + 'code': code, + 'msg': msg + } + if data is not None: + dict['data'] = data + + dict.update(kwargs) + + super().__init__( + data=dict, + status=status, + template_name=None, headers=headers, + exception=False, content_type=content_type + ) class ChoiceValueFilter(): diff --git a/backend/settings.py b/backend/settings.py index 6ccf49216685b196f36419f75921e5e765b9503f..6b892882d9b9f342aaa55c121569cfb3f69606ad 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -40,6 +40,9 @@ INSTALLED_APPS = [ 'django.contrib.messages', 'django.contrib.staticfiles', + # neo4j, must before apps. + 'django_neomodel', + # --- django dependencies --- 'corsheaders', 'django_filters', @@ -68,6 +71,9 @@ INSTALLED_APPS = [ ] MIDDLEWARE = [ + # Caches, order matters, update must be the first one. + 'django.middleware.cache.UpdateCacheMiddleware', + 'corsheaders.middleware.CorsMiddleware', # --- django --- @@ -80,8 +86,15 @@ MIDDLEWARE = [ 'django.middleware.clickjacking.XFrameOptionsMiddleware', 'backend.middleware.NotUseCsrfTokenMiddlewareMixin', + + # FetchFromCacheMiddleware caches `GET` `HEAD` with status 200, + # must be the last one. + 'django.middleware.cache.FetchFromCacheMiddleware', ] +CACHE_MIDDLEWARE_KEY_PREFIX = "cache_redis_middleware" +SESSION_ENGINE = 'django.contrib.sessions.backends.cache' + ROOT_URLCONF = 'backend.urls' TEMPLATES = [ @@ -116,8 +129,43 @@ DATABASES = { 'HOST': '127.0.0.1', 'PORT': '3306' }, + 'neo4j': { + 'ENGINE': '', + 'NAME': 'neo4j', + 'USER': 'neo4j', + 'PASSWORD': 'neo4j', + 'HOST': '127.0.0.1', + 'PORT': '7687', + 'TEST': { + 'NAME': 'test_neo4j', + } + } } +# Cache +CACHES = { + 'default': { + 'BACKEND': "django_redis.cache.RedisCache", + 'LOCATION': "redis://127.0.0.1:6379", + 'OPTIONS': { + 'CLIENT_CLASS': 'django_redis.client.DefaultClient', + 'CONNECTION_POOL_KWARGS': { + 'max_connections': 20 + } + } + } +} +# Neo4j Qy3fbDPjv_bcp-CjVH6HW7d6vjC8YOmVYOU8LpbCfmE neo4j+s://f954cb26.databases.neo4j.io +# username neo4j password neo4j +NEO4J_BOLT_URL = 'bolt://neo4j:neo4j@localhost:7687' +NEOMODEL_NEO4J_BOLT_URL = os.environ.get( + 'NEO4J_BOLT_URL', + 'bolt://neo4j:neo4j@localhost:7687' +) +NEOMODEL_SIGNALS = True +NEOMODEL_FORCE_TIMEZONE = False +NEOMODEL_ENCRYPTED_CONNECTION = False +NEOMODEL_MAX_CONNECTION_POOL_SIZE = 50 # Password validation # https://docs.djangoproject.com/en/3.2/ref/settings/#auth-password-validators @@ -157,7 +205,7 @@ USE_TZ = False # https://docs.djangoproject.com/en/3.2/howto/static-files/ STATIC_URL = '/static/' -STATIC_ROOT = BASE_DIR/'static' +STATIC_ROOT = BASE_DIR / 'static' STATICFILES_DIRS = [ os.path.join(BASE_DIR, '/static/'), @@ -222,4 +270,8 @@ APIKEY = "NONE" # ) - +# CACHE_MIDDLEWARE_ALIAS -- 用于存储的缓存别名 +# CACHE_MIDDLEWARE_SECONDS -- 应缓存每个页面的秒数 +# 如果使用相同的 Django installation ,通过多站点进行缓存共享,设置为站点名 +# 或者设置成在 Django 实例中唯一的其他字符串,以此防止键冲突。必要时可以设置成空字符串。 +# CACHE_MIDDLEWARE_KEY_PREFIX diff --git a/index.py b/index.py index 1ad7946558833a6c5c87c9513771f008dc999e8e..cc8b5cd09599e4d2040bd8acadca0478f97a6c26 100644 --- a/index.py +++ b/index.py @@ -11,7 +11,7 @@ prefix = ('https://api.qiusuo-mc.cn/' @permission_classes([AllowAny]) def root(request): return Response({ - 'api-version': '1.2', + 'api-version': '1.3', 'swagger-doc': prefix + 'doc/', 'docs': prefix + 'docs/', 'operation-manage': prefix + 'operation-manage/', diff --git a/operations/serializers.py b/operations/serializers.py index bc6de78036085cbbe5d9309a6379494c3a1df339..87859ab7bb69156537985db00f282ba1409498f5 100644 --- a/operations/serializers.py +++ b/operations/serializers.py @@ -3,7 +3,6 @@ from datetime import datetime, timedelta from django.contrib.auth.hashers import make_password from django.contrib.auth.models import User -from django.db.models import Q from django.utils import timezone from rest_framework import serializers from rest_framework.exceptions import ValidationError @@ -11,7 +10,7 @@ from rest_framework_jwt.serializers import jwt_payload_handler, jwt_encode_handl from backend import constants from operations.models import VerifyCode -from users.models import UserInfo +from users.models import UserInfo, UserNode from users.serializers import UserProfileSerializer @@ -228,44 +227,63 @@ class FileUploadSerializer(serializers.Serializer): fields = ['file'] -class FollowSerializer(serializers.ModelSerializer): - follow = serializers.PrimaryKeyRelatedField( - queryset=UserInfo.objects.all() - ) - - class Meta: - model = UserInfo - fields = ['id', 'follow'] +class FollowSerializer(serializers.Serializer): @property def user(self): return self.context['request'].user def validate(self, attrs): - follow = attrs.get('follow') - attrs['fan'] = self.user - if not follow: - raise serializers.ValidationError("用户不存在") + id = self.initial_data.get('id') + if not id: + raise ValidationError("id 不能为空!") + + if id == self.user.id: + raise ValidationError("不能关注自己,请检查 id 是否有误") + + # 防御性 + try: + target = UserInfo.objects.get(id=id) + except UserInfo.DoesNotExist: + raise ValidationError("用户不存在,请检查 id 是否有误") + + attrs['id'] = id + attrs['target'] = target + + from operations.utils import get_or_create + user_node = get_or_create( + user_id=self.user.id, + username=self.user.username, + avatar=self.user.userinfo.avatar + ) + target_node = get_or_create( + user_id=id, + username=target.username, + avatar=target.userinfo.avatar + ) + + attrs['user_node'] = user_node + attrs['target_node'] = target_node return attrs def create(self, validated_data): - validated_data['follow'].fans.add(validated_data['fan'].userinfo) - - def update(self, instance, validated_data): - instance.fans.remove(validated_data['fan']) - - def save(self): - master = self.validated_data.get('follow') - fan = self.validated_data.get('fan') - # 当前用户是否关注了目标用户? - has_follow = UserInfo.objects.filter( - Q(id=master.id) & Q(fans=fan.id) - ).count() - - if has_follow == 0: - self.create(self.validated_data) + id = validated_data['id'] + target = validated_data['target'] + user_node: UserNode = validated_data['user_node'] + target_node: UserNode = validated_data['target_node'] + + assert user_node != target_node + + try: + find = user_node.follows.get(user_id=id) + except user_node.DoesNotExist: + # 未关注。 + user_node.follows.connect(target_node) + target_node.fans.connect(user_node) + except Exception: + raise Exception() else: - self.update(master, self.validated_data) - master.save() - - return 'update' if has_follow != 0 else 'create' + # 已关注,本次请求取消关注。 + user_node.follows.disconnect(target_node) + target_node.fans.disconnect(user_node) + return True diff --git a/operations/tests.py b/operations/tests.py index a39b155ac3ee946fb97efafe6ecbb42f571cd7ad..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/operations/tests.py +++ b/operations/tests.py @@ -1 +0,0 @@ -# Create your tests here. diff --git a/operations/utils.py b/operations/utils.py index 2e6b630567b74d40b5f58c984947b2ffc6ba59fc..9e4f367770390ad657b48aa854cb7ef95d0787d2 100644 --- a/operations/utils.py +++ b/operations/utils.py @@ -1,3 +1,8 @@ +from neomodel import db + +from users.models import UserNode + + def jwt_response_payload_handler(token, user=None, request=None): """ 自定义jwt认证成功返回参数 @@ -7,3 +12,20 @@ def jwt_response_payload_handler(token, user=None, request=None): 'user': user.id, 'username': user.username } + + +@db.transaction +def get_or_create(user_id: int, username: str, avatar: str) -> UserNode: + try: + n = UserNode.nodes.get( + user_id=user_id, + username=username, + avatar=avatar + ) + except UserNode.DoesNotExist: + n = UserNode( + user_id=user_id, + username=username, + avatar=avatar + ).save() + return n diff --git a/operations/views/user_operations/follow.py b/operations/views/user_operations/follow.py index eb70dc60f146b93b02998c29a1bd09135abde26c..89d990ca5378d2e8b9b320b2d0125bb03c46132c 100644 --- a/operations/views/user_operations/follow.py +++ b/operations/views/user_operations/follow.py @@ -1,33 +1,25 @@ from rest_framework import mixins from rest_framework import permissions, status from rest_framework.decorators import permission_classes -from rest_framework.response import Response from rest_framework.viewsets import GenericViewSet +from backend.helper import Response from operations.serializers import FollowSerializer from users.models import UserInfo @permission_classes([permissions.IsAuthenticated]) -class FollowAPIView(GenericViewSet, - mixins.CreateModelMixin, - mixins.UpdateModelMixin): +class FollowAPIView(GenericViewSet, mixins.CreateModelMixin): queryset = UserInfo.objects.all() serializer_class = FollowSerializer def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) - action = self.perform_create(serializer) + self.perform_create(serializer) headers = self.get_success_headers(serializer.data) - return Response(data={ 'success': True, - 'code': 201, - 'user': self.request.user.id, - 'msg': '成功关注' if action == 'create' else '取消关注', - 'target_user': serializer.data['follow'] - }, status=status.HTTP_201_CREATED, headers=headers) - - def perform_create(self, serializer): - return serializer.save() + 'code': 200, + 'msg': 'success' + }, status=status.HTTP_200_OK, headers=headers) diff --git a/requirements.txt b/requirements.txt index bd85f854562d6be5b879487a82dffe347b290c13..a81962ff2056de700227cd516ba030a4978b7d91 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/users/models.py b/users/models.py index cc0c6e7e3034ff91c37d56240250ba877e3e353a..7ffec2756d1c67ac622b00f0ca2b40aede23aa01 100644 --- a/users/models.py +++ b/users/models.py @@ -1,6 +1,11 @@ +from datetime import datetime + +import neomodel +import pytz from django.contrib import admin from django.contrib.auth.models import User from django.db import models +from neomodel import StructuredNode, StructuredRel, ZeroOrMore SEX_CHOICE = ((1, u'男'), (2, u'女'), (3, u'未知')) @@ -146,3 +151,33 @@ class UserTitle(models.Model): db_table = 'UserTitle' verbose_name = '用户头衔' verbose_name_plural = verbose_name + + +class FollowRel(StructuredRel): + since = neomodel.DateTimeProperty( + default=lambda: datetime.now(pytz.utc) + ) + + +class UserNode(StructuredNode): + uid = neomodel.UniqueIdProperty() + user_id = neomodel.IntegerProperty(unique_index=True, required=True) + username = neomodel.StringProperty(unique_index=True, required=True) + avatar = neomodel.StringProperty( + default='https://img.qiusuo-mc.cn/media/avatar/825ed0c3b35e053013344ee5ade03458.png' + ) + fans = neomodel.RelationshipTo( + 'UserNode', + 'FANS', + cardinality=ZeroOrMore, + model=FollowRel + ) + follows = neomodel.RelationshipTo( + 'UserNode', + 'FOLLOWS', + cardinality=ZeroOrMore, + model=FollowRel + ) + + class Meta: + app_label = 'user_relation' diff --git a/users/serializers.py b/users/serializers.py index ec552a6e63650c661be9e9a4dad9f0f491d942f7..06f8c6bd7795f819a2b774f4c28480ccec0691f6 100644 --- a/users/serializers.py +++ b/users/serializers.py @@ -187,60 +187,7 @@ class FanFollowDetailSerializer(serializers.ModelSerializer): class UserFansSerializer(serializers.ModelSerializer): - fans = serializers.SerializerMethodField() - fans_count = 0 - follows = serializers.SerializerMethodField() - follows_count = 0 class Meta: model = UserInfo - fields = ['id', 'username', 'fans', 'follows'] - - def get_fans(self, obj): - fans = UserInfo.objects.raw( - ''' - SELECT * - FROM - local_qiusuo.UserInfo as A, local_qiusuo.UserInfo_fans as uf - WHERE - uf.from_userinfo_id=%s AND uf.to_userinfo_id=A.user_ptr_id - ORDER BY - uf.created_at DESC - ''', [obj.id] - ) - self.fans_count = len(fans) - fans_dict = [] - for fan in fans: - ser_data = FanFollowDetailSerializer(fan).data - # The `fan.id` we get is actually from table `userinfo_fans`.`id` - # which is only a pk with no actual meaning. - # In order to display a meaningful id, - # we assign `user_ptr_id` to the `id` field. - ser_data['id'] = fan.user_ptr_id - fans_dict.append(ser_data) - return fans_dict - - def get_follows(self, obj): - follows = UserInfo.objects.raw( - ''' - SELECT A.user_ptr_id - FROM - local_qiusuo.UserInfo as A, local_qiusuo.UserInfo_fans as uf - WHERE - uf.to_userinfo_id=%s AND uf.from_userinfo_id=A.user_ptr_id - ORDER BY uf.created_at DESC - ''', [obj.id] - ) - self.follows_count = len(follows) - follows_dict = [] - for follow in follows: - ser_data = FanFollowDetailSerializer(follow).data - ser_data['id'] = follow.user_ptr_id - follows_dict.append(ser_data) - return follows_dict - - def to_representation(self, instance): - res = super().to_representation(instance) - res['fans_count'] = self.fans_count - res['follows_count'] = self.follows_count - return res + fields = ['id', 'username'] diff --git a/users/tests.py b/users/tests.py index a6677f5642208f15b15c6df012bae4161fcf66c8..eb0436e90255215eda8e90db0af35957e60ae16c 100644 --- a/users/tests.py +++ b/users/tests.py @@ -1,87 +1,156 @@ -import re # 正则 -from collections import OrderedDict # 创建有序字典 - -from django.conf import settings # 导入setting配置文件(获取项目根路径urls.py) -from django.urls import URLPattern, URLResolver # Django自定义的类,可以判断当前URL是否为根路径(URLPattern)或继续向下分发(URLResolver ) -from django.utils.module_loading import import_string # 字符串导入模块 - - -def check_url_exclude(url): - """ - 排除一些特定的URL - :param url: 待检验的URL - :return: - """ - for regex in settings.AUTO_DISCOVER_EXCLUDE: # 将要定向排除的URL(可包含正则) 按照列表的形式写入配置文件 settings.AUTO_DISCOVER_EXCLUDE 下 - if re.match(regex, url): - return True - - -""" -示例: -AUTO_DISCOVER_EXCLUDE = [ - '/admin/.*', - '/login/', - '/logout/', - '/index/', -] -""" - - -def recursion_urls(pre_namespace, pre_url, urlpatterns, url_ordered_dict): - """ - 递归的去获取URL - :param pre_namespace: namespace前缀,用于拼接name - :param pre_url: url前缀,用于拼接url - :param urlpatterns: 路由关系列表 - :param url_ordered_dict: 用于保存递归中获取的所有路由 - :return: - """ - for item in urlpatterns: - if isinstance(item, URLPattern): # 已经是根网址,获取name及url写入url_ordered_dict - if not item.name: # 没有别名(name)的路由地址直接跳过 - continue - # 拼接路由别名(包含分发下来的namespace;如 "rbac:menu_list") - if pre_namespace: - name = "%s:%s" % (pre_namespace, item.name) - else: - name = item.name - # 拼接路由地址URl(包含分发下来的上层路由;如 "/rbac/menu/list") - url = pre_url + item.pattern.regex.pattern # 此时拼接的路由包含起止符号,如:/^rbac/^menu/list/$ - url = url.replace("^", "").replace("$", "") # 删除起止符:/rbac/menu/list/ - # 排除一些特定的路由URL - if check_url_exclude(url): # 调用check_url_exclude函数定向排除部分URL - continue - url_ordered_dict[name] = {"name": name, "url": url} - - elif isinstance(item, URLResolver): # 路由分发,递归操作 - if pre_namespace: # 上次循环(上一层)分发是否包含namespace - if item.namespace: # 本次循环(当前层)是否包含namespace - namespace = "%s:%s" % (pre_namespace, item.namespace,) # 上层、当前层都包含直接拼接两层的namespace - else: - namespace = pre_namespace # 当前层分发不包含namespace,直接用上一层的 - else: - if item.namespace: - namespace = item.namespace # 上一层分发不包含namespace,直接使用当前层的 - else: - namespace = None # 上一层、当前层都没有,直接定义层none - recursion_urls(namespace, pre_url + item.pattern.regex.pattern, item.url_patterns, - url_ordered_dict) # 递归继续执行 - - -def get_all_url_dict(): - """ - 获取项目所有路由 - :return: - """ - url_ordered_dict = OrderedDict() # 包含本项目所有权限URl的有序字典 - md = import_string(settings.ROOT_URLCONF) # 配置文件内的 ROOT_URLCONF 为本项目根路由urls.py 的路径(字符串),使用 import_string 用字符串加载模块 - recursion_urls(None, "/", md.urlpatterns, - url_ordered_dict) # 调用 recursion_urls 函数获取所有路由字典,根路径下没有namespace 定义为 None;没有url前缀 定义为 / - return url_ordered_dict - - -if __name__ == '__main__': - - dict = get_all_url_dict() +# 写的什么东西?我不写 test 还发现不了,django 的 test 是你这样写的? +# import re # 正则 +# from collections import OrderedDict # 创建有序字典 +# +# from django.conf import settings # 导入setting配置文件(获取项目根路径urls.py) +# from django.urls import URLPattern, URLResolver # Django自定义的类,可以判断当前URL是否为根路径(URLPattern)或继续向下分发(URLResolver ) +# from django.utils.module_loading import import_string # 字符串导入模块 +# +# +# def check_url_exclude(url): +# """ +# 排除一些特定的URL +# :param url: 待检验的URL +# :return: +# """ +# for regex in settings.AUTO_DISCOVER_EXCLUDE: # 将要定向排除的URL(可包含正则) 按照列表的形式写入配置文件 settings.AUTO_DISCOVER_EXCLUDE 下 +# if re.match(regex, url): +# return True +# +# +# """ +# 示例: +# AUTO_DISCOVER_EXCLUDE = [ +# '/admin/.*', +# '/login/', +# '/logout/', +# '/index/', +# ] +# """ +# +# +# def recursion_urls(pre_namespace, pre_url, urlpatterns, url_ordered_dict): +# """ +# 递归的去获取URL +# :param pre_namespace: namespace前缀,用于拼接name +# :param pre_url: url前缀,用于拼接url +# :param urlpatterns: 路由关系列表 +# :param url_ordered_dict: 用于保存递归中获取的所有路由 +# :return: +# """ +# for item in urlpatterns: +# if isinstance(item, URLPattern): # 已经是根网址,获取name及url写入url_ordered_dict +# if not item.name: # 没有别名(name)的路由地址直接跳过 +# continue +# # 拼接路由别名(包含分发下来的namespace;如 "rbac:menu_list") +# if pre_namespace: +# name = "%s:%s" % (pre_namespace, item.name) +# else: +# name = item.name +# # 拼接路由地址URl(包含分发下来的上层路由;如 "/rbac/menu/list") +# url = pre_url + item.pattern.regex.pattern # 此时拼接的路由包含起止符号,如:/^rbac/^menu/list/$ +# url = url.replace("^", "").replace("$", "") # 删除起止符:/rbac/menu/list/ +# # 排除一些特定的路由URL +# if check_url_exclude(url): # 调用check_url_exclude函数定向排除部分URL +# continue +# url_ordered_dict[name] = {"name": name, "url": url} +# +# elif isinstance(item, URLResolver): # 路由分发,递归操作 +# if pre_namespace: # 上次循环(上一层)分发是否包含namespace +# if item.namespace: # 本次循环(当前层)是否包含namespace +# namespace = "%s:%s" % (pre_namespace, item.namespace,) # 上层、当前层都包含直接拼接两层的namespace +# else: +# namespace = pre_namespace # 当前层分发不包含namespace,直接用上一层的 +# else: +# if item.namespace: +# namespace = item.namespace # 上一层分发不包含namespace,直接使用当前层的 +# else: +# namespace = None # 上一层、当前层都没有,直接定义层none +# recursion_urls(namespace, pre_url + item.pattern.regex.pattern, item.url_patterns, +# url_ordered_dict) # 递归继续执行 +# +# +# def get_all_url_dict(): +# """ +# 获取项目所有路由 +# :return: +# """ +# url_ordered_dict = OrderedDict() # 包含本项目所有权限URl的有序字典 +# md = import_string(settings.ROOT_URLCONF) # 配置文件内的 ROOT_URLCONF 为本项目根路由urls.py 的路径(字符串),使用 import_string 用字符串加载模块 +# recursion_urls(None, "/", md.urlpatterns, +# url_ordered_dict) # 调用 recursion_urls 函数获取所有路由字典,根路径下没有namespace 定义为 None;没有url前缀 定义为 / +# return url_ordered_dict +# +# +# if __name__ == '__main__': +# +# dict = get_all_url_dict() + +from django.test import TestCase + +from users.models import UserNode + + +class Neo4jUserRelTestCase(TestCase): + def setUp(self) -> None: + Alice = self.create_node(username='Alice') + Bob = self.create_node(username='Bob') + Charlie = self.create_node(username='Charlie') + David = self.create_node(username='David') + + self.a_follow_b(Alice, Bob) + self.a_follow_b(Alice, Charlie) + self.a_follow_b(Alice, David) + + self.a_follow_b(Bob, Alice) + self.a_follow_b(Bob, David) + + def test_relationship(self): + """ + Alice follows Bob, Charlie, David. + Bob follows Alice, David. + + Alice's fans are: Bob. + Bob's fans are: Alice. + Charlie's fans are: Alice. + David's fans are: Alice, Bob. + + """ + alice = UserNode.nodes.get_or_none(username='Alice', user_id=1) + bob = UserNode.nodes.get_or_none(username='Bob', user_id=2) + charlie = UserNode.nodes.get_or_none(username='Charlie', user_id=3) + david = UserNode.nodes.get_or_none(username='David', user_id=4) + + alice_follow_list = alice.follows.all() + bob_follow_list = bob.follows.all() + charlie_follow_list = charlie.follows.all() + david_follow_list = david.follows.all() + + self.assertEqual(alice_follow_list.__len__(), 3) + self.assertEqual(bob_follow_list.__len__(), 2) + self.assertEqual(charlie_follow_list.__len__(), 0) + self.assertEqual(david_follow_list.__len__(), 0) + + alice_fan_list = alice.fans.all() + bob_fan_list = bob.fans.all() + charlie_fan_list = charlie.fans.all() + david_fan_list = david.fans.all() + + self.assertEqual(alice_fan_list.__len__(), 1) + self.assertEqual(bob_fan_list.__len__(), 1) + self.assertEqual(charlie_fan_list.__len__(), 1) + self.assertEqual(david_fan_list.__len__(), 2) + + @staticmethod + def a_follow_b(A: UserNode, B: UserNode) -> None: + """ + A follow B. First add a FollowRel edge from A to B, + then add a FanRel edge from B to A. + """ + A.follows.connect(B) + B.fans.connect(A) + + @staticmethod + def create_node(username: str) -> UserNode: + return UserNode(username=username).save() diff --git a/users/urls.py b/users/urls.py index 049c0075da7dcc3794a8291904c02079c3363a99..c0dddedd35ae25d076a372de7311505141631ec6 100644 --- a/users/urls.py +++ b/users/urls.py @@ -1,7 +1,7 @@ from django.urls import path, include from rest_framework import routers -from users.views import user_api, token_api, follow_api +from users.views import user_api, token_api user_api_router = routers.DefaultRouter() @@ -13,17 +13,13 @@ user_api_router.register( user_api_router.register( 'titles', user_api.UserTitleViewSet, - basename='titles') + basename='titles' +) user_api_router.register( 'tokens', token_api.TokenViewSet, basename='tokens' ) -user_api_router.register( - 'follows', - follow_api.FanFollowViewSet, - basename='follows' -) urlpatterns = [ path('', include((user_api_router.urls, 'users'), namespace='users')) diff --git a/users/views/follow_api.py b/users/views/follow_api.py deleted file mode 100644 index f038603765b342b62087c3b42378ca4f38556b59..0000000000000000000000000000000000000000 --- a/users/views/follow_api.py +++ /dev/null @@ -1,17 +0,0 @@ -from django_filters.rest_framework import DjangoFilterBackend -from rest_framework import permissions -from rest_framework.decorators import permission_classes -from rest_framework.filters import OrderingFilter, SearchFilter - -from backend import helper -from users.models import UserInfo -from users.serializers import UserFansSerializer - - -@permission_classes([permissions.IsAuthenticated]) -class FanFollowViewSet(helper.MyModelViewSet): - serializer_class = UserFansSerializer - queryset = UserInfo.objects.all() - - filter_backends = [DjangoFilterBackend, OrderingFilter, SearchFilter] - # TODO: ordering & filtering with `fans_count` field diff --git a/users/views/user_api.py b/users/views/user_api.py index 833f9859ec3dd98d047e020255f84d9fd65ca8c0..ce98bbca3cdbd41b0e60c9b073f67336c124f753 100644 --- a/users/views/user_api.py +++ b/users/views/user_api.py @@ -8,12 +8,13 @@ from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.viewsets import ModelViewSet +from backend.helper import APIResponse from users.filters.filters import UserFilter -from users.models import UserInfo, UserTitle +from users.models import UserInfo, UserTitle, UserNode from users.serializers import ( UserTitleSerializer, UserProfileSerializer, - ResetPasswordSerializer + ResetPasswordSerializer, FanFollowDetailSerializer ) from users.utils import get_user_by_email_or_phone @@ -75,6 +76,63 @@ class UserInfoViewSet(ModelViewSet): from utils.upload import upload_avatar upload_avatar(request) + @action(methods=['GET'], detail=True, + permission_classes=[permissions.AllowAny], + serializer_class=[FanFollowDetailSerializer]) + def fans(self, request, pk): + """ + 用户的粉丝列表 + """ + node = UserNode.nodes.get(user_id=pk) + fans = node.fans.all() + + fan_list = [] + for f in fans: + one = { + 'id': f.id, + 'uid': f.uid, + 'user_id': f.user_id, + 'username': f.username, + 'avatar': f.avatar + } + fan_list.append(one) + + return APIResponse( + success=True, + code=200, + msg='success', + data=fan_list, + status=status.HTTP_200_OK + ) + + @action(methods=['GET'], detail=True, + permission_classes=[permissions.AllowAny], + serializer_class=[FanFollowDetailSerializer]) + def follows(self, request, pk): + """ + 用户的关注列表 + """ + node = UserNode.nodes.get(user_id=pk) + follows = node.follows.all() + + follow_list = [] + for f in follows: + one = { + 'id': f.id, + 'uid': f.uid, + 'user_id': f.user_id, + 'username': f.username, + 'avatar': f.avatar + } + follow_list.append(one) + + return APIResponse( + success=True, + code=200, + msg='success', + data=follow_list, + status=status.HTTP_200_OK + ) class UserTitleViewSet(ModelViewSet): """ diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391