You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

208 lines
7.6 KiB

import hashlib
import hmac
import os
import re
from datetime import datetime, timedelta, timezone
from typing import Optional

from bson.objectid import ObjectId
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from pydantic import BaseModel, Field
from pymongo import MongoClient

# Best workaround found for _id typed as ObjectId (creating Exception bcause JSON doesn't support custom types countrary to BSON, used by Mongo)
# also allows to create DTOs at the time, but not at it's best (project structure is chaotic FTM :s)
from serializers import friends_serialize, pins_serialize, users_serialize
# Constants for JWT
# The signing key is read from the SECRET_KEY environment variable so that a
# real deployment never has to rely on a value committed to source control.
# The literal below is only a development fallback and must not be used in
# production.
SECRET_KEY = os.environ.get(
    "SECRET_KEY",
    "_2YfT44$xF.Tg_xI63UH3D7:N+>pZN2';j%>7H@?e0:Xor'pV[",
)
ALGORITHM = "HS256"  # TODO: check if broken (don't believe)
# Lifetime, in minutes, of the tokens issued by /register and /login.
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # TODO: evaluate how critical this value is / what it should be
# Database setup
# Connection settings can be overridden through MONGO_URL, MONGO_USERNAME and
# MONGO_PASSWORD; the defaults reproduce the previous hard-coded local setup.
client = MongoClient(
    os.environ.get("MONGO_URL", "mongodb://localhost:27017/"),
    username=os.environ.get("MONGO_USERNAME", "mongoadmin"),
    password=os.environ.get("MONGO_PASSWORD", "secret"),
)
db = client["memorymap"]
# FastAPI app instance
app = FastAPI()
# OAuth2 scheme: bearer tokens are obtained from the /login route.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
# Pydantic models
class User(BaseModel):
    """Account payload accepted by /register.

    Also used as the annotation of ``current_user`` in the protected routes.
    """

    # Database identifier, exposed under the "_id" alias. It is optional
    # because a client registering a new account cannot know it yet, and
    # register() never reads it.
    uid: Optional[str] = Field(None, alias="_id")
    username: str
    # Plaintext password as sent by the client; register() hashes it before
    # storing it.
    password: str
class Token(BaseModel):
    """Response body declared as ``response_model`` by /register and /login."""

    # Encoded JWT, as produced by create_access_token().
    access_token: str
    # Token scheme; the routes in this file always send "bearer".
    token_type: str
class TokenData(BaseModel):
    """Identity extracted from a decoded JWT by get_current_user()."""

    # Value of the token's "sub" claim.
    username: Optional[str] = None
class Pin(BaseModel):
    """Request body used by add_pin() and update_pin()."""

    title: str
    description: str
class Friend(BaseModel):
    """Request body used by add_friend()."""

    # Presumably the id of the user being befriended -- confirm with callers.
    user_id: str
# Collections
# Handles on the three MongoDB collections of the "memorymap" database that
# the routes below read from and write to.
users_collection = db["users"]
pins_collection = db["pins"]
friends_collection = db["friends"]
# Utility functions
def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored SHA-256 hex digest.

    Args:
        plain_password: Password as typed by the user.
        hashed_password: Hex digest previously produced by
            get_password_hash().

    Returns:
        True when the digest of ``plain_password`` equals ``hashed_password``,
        False otherwise (including when ``hashed_password`` is not a string).
    """
    candidate = hashlib.sha256(plain_password.encode()).hexdigest()
    if not isinstance(hashed_password, str):
        # Keep the behaviour of a plain equality test: a missing or malformed
        # stored hash never matches.
        return False
    # Constant-time comparison, so response timing does not reveal how many
    # leading characters of the digest were correct.
    return hmac.compare_digest(candidate.encode(), hashed_password.encode())
def get_password_hash(password):
    """Return the hexadecimal SHA-256 digest of ``password``.

    NOTE(review): this is a single unsalted SHA-256 pass; consider a dedicated
    password-hashing scheme when stored hashes can be migrated.
    """
    hasher = hashlib.sha256()
    hasher.update(password.encode())
    return hasher.hexdigest()
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Build a signed JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed in the token; the dict is copied, not mutated.
        expires_delta: Lifetime of the token. When omitted (or falsy) the
            token expires after 15 minutes.

    Returns:
        The token encoded with SECRET_KEY and ALGORITHM.
    """
    lifetime = expires_delta if expires_delta else timedelta(minutes=15)
    to_encode = data.copy()
    # Use an aware UTC timestamp: a naive local datetime would be interpreted
    # as UTC when converted to the numeric "exp" claim, shifting the expiry by
    # the server's UTC offset.
    to_encode["exp"] = datetime.now(timezone.utc) + lifetime
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the bearer token of the current request to a stored user.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``.

    Returns:
        The user document whose ``username`` equals the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when the token cannot be decoded, has no ``sub``
            claim, or names a user that is not in the users collection.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized

    subject = claims.get("sub")
    if subject is None:
        raise unauthorized

    token_data = TokenData(username=subject)
    account = users_collection.find_one({"username": token_data.username})
    if account is None:
        raise unauthorized
    return account
# Routes - TODO: find workaround to display 401/409/... HTTP error codes in openapi.json
@app.post("/register", response_model=Token)
async def register(user: User):
    """Create an account and return a bearer token for it.

    Raises:
        HTTPException: 409 when the username is already present in the users
            collection.
    """
    # NOTE(review): the lookup and the insert are two separate operations, so
    # concurrent registrations of the same name are only prevented if the
    # collection has a unique index on "username" -- confirm.
    if users_collection.find_one({"username": user.username}):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Username already used"
        )

    new_account = {
        "username": user.username,
        "password": get_password_hash(user.password),
    }
    users_collection.insert_one(new_account)

    token = create_access_token(
        data={"sub": user.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": token, "token_type": "bearer"}
@app.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Authenticate with username and password and return a bearer token.

    Raises:
        HTTPException: 401 when the username is unknown or the password does
            not match the stored hash.
    """
    account = users_collection.find_one({"username": form_data.username})
    if not (account and verify_password(form_data.password, account["password"])):
        # Same answer for "no such user" and "wrong password".
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = create_access_token(
        data={"sub": form_data.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": token, "token_type": "bearer"}
@app.get("/logout")
async def logout(current_user: User = Depends(get_current_user)):
    """Acknowledge a logout request from an authenticated user.

    No state is changed: the presented token stays valid until it expires.
    """
    # TODO: find usecase / what to do ??
    acknowledgement = {"message": "Logged out"}
    return acknowledgement
@app.get("/pin/{id}")
async def get_pin(id: str, current_user: User = Depends(get_current_user)):
    """Return the pin whose ``_id`` matches ``id``.

    Raises:
        HTTPException: 400 when ``id`` is not a valid ObjectId string,
            404 when no pin has that id.
    """
    # ObjectId() raises on malformed input; reject it up front instead of
    # letting the request fail with a 500.
    if not ObjectId.is_valid(id):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid pin id")
    pin = pins_collection.find_one({"_id": ObjectId(id)})
    if pin is None:
        raise HTTPException(status_code=404, detail="Pin not found")
    # A raw ObjectId cannot be encoded as JSON, so expose the id as a string.
    pin["_id"] = str(pin["_id"])
    return pin
@app.patch("/pin/{id}")
async def update_pin(id: str, pin: Pin, current_user: User = Depends(get_current_user)):
    """Overwrite the fields of an existing pin with the values in ``pin``.

    Raises:
        HTTPException: 400 when ``id`` is not a valid ObjectId string,
            404 when no pin has that id.
    """
    # NOTE(review): any authenticated user can update any pin; there is no
    # ownership check here.
    # ObjectId() raises on malformed input; reject it up front instead of
    # letting the request fail with a 500.
    if not ObjectId.is_valid(id):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid pin id")
    result = pins_collection.update_one({"_id": ObjectId(id)}, {"$set": pin.model_dump()})
    if result.matched_count == 0:
        raise HTTPException(status_code=404, detail="Pin not found")
    return {"message": "Pin updated"}
@app.post("/pin/add")
async def add_pin(pin: Pin, current_user: User = Depends(get_current_user)):
    """Store a new pin and return its generated id as a string."""
    insertion = pins_collection.insert_one(pin.model_dump())
    return {"id": str(insertion.inserted_id)}
@app.get("/pins")
async def list_pins(current_user: User = Depends(get_current_user)):
    """Return every document of the pins collection, passed through pins_serialize."""
    # The query has no filter: pins are not restricted to the current user.
    every_pin = pins_collection.find().to_list()
    return pins_serialize(every_pin)
@app.get("/friend/{id}")
async def get_friend(id: str, current_user: User = Depends(get_current_user)):
    """Return the friend document whose ``_id`` matches ``id``.

    Raises:
        HTTPException: 400 when ``id`` is not a valid ObjectId string,
            404 when no friend document has that id.
    """
    # ObjectId() raises on malformed input; reject it up front instead of
    # letting the request fail with a 500.
    if not ObjectId.is_valid(id):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid friend id")
    friend = friends_collection.find_one({"_id": ObjectId(id)})
    if friend is None:
        raise HTTPException(status_code=404, detail="Friend not found")
    # A raw ObjectId cannot be encoded as JSON, so expose the id as a string.
    friend["_id"] = str(friend["_id"])
    return friend
@app.post("/friend/add")
async def add_friend(friend: Friend, current_user: User = Depends(get_current_user)):
    """Store a new friend document and return its generated id as a string."""
    # TODO: test if exists
    insertion = friends_collection.insert_one(friend.model_dump())
    return {"id": str(insertion.inserted_id)}
@app.delete("/friend/{id}/delete")
async def delete_friend(id: str, current_user: User = Depends(get_current_user)):
    """Delete the friend document whose ``_id`` matches ``id``.

    Raises:
        HTTPException: 400 when ``id`` is not a valid ObjectId string,
            404 when no friend document has that id.
    """
    # NOTE(review): any authenticated user can delete any friend document;
    # there is no ownership check here.
    # ObjectId() raises on malformed input; reject it up front instead of
    # letting the request fail with a 500.
    if not ObjectId.is_valid(id):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid friend id")
    result = friends_collection.delete_one({"_id": ObjectId(id)})
    if result.deleted_count == 0:
        raise HTTPException(status_code=404, detail="Friend not found")
    return {"message": "Friend deleted"}
@app.patch("/friend/{id}/accept")
async def accept_friend(id: str, current_user: User = Depends(get_current_user)):
    """Set the ``status`` of a friend document to "accepted".

    Raises:
        HTTPException: 400 when ``id`` is not a valid ObjectId string,
            404 when no friend document has that id.
    """
    # ObjectId() raises on malformed input; reject it up front instead of
    # letting the request fail with a 500.
    if not ObjectId.is_valid(id):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid friend id")
    result = friends_collection.update_one({"_id": ObjectId(id)}, {"$set": {"status": "accepted"}})
    if result.matched_count == 0:
        raise HTTPException(status_code=404, detail="Friend not found")
    return {"message": "Friend request accepted"}
@app.post("/friend/{id}/deny")
async def deny_friend(id: str, current_user: User = Depends(get_current_user)):
    """Set the ``status`` of a friend document to "denied".

    Raises:
        HTTPException: 400 when ``id`` is not a valid ObjectId string,
            404 when no friend document has that id.
    """
    # NOTE(review): this route is a POST while the accept route is a PATCH;
    # the verb is kept as is so existing clients keep working.
    # ObjectId() raises on malformed input; reject it up front instead of
    # letting the request fail with a 500.
    if not ObjectId.is_valid(id):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid friend id")
    result = friends_collection.update_one({"_id": ObjectId(id)}, {"$set": {"status": "denied"}})
    if result.matched_count == 0:
        raise HTTPException(status_code=404, detail="Friend not found")
    return {"message": "Friend request denied"}
@app.get("/friends")
async def list_friends(current_user: User = Depends(get_current_user)):
    """Return every document of the friends collection, passed through friends_serialize."""
    # The query has no filter: results are not restricted to the current user.
    every_friend = friends_collection.find().to_list()
    return friends_serialize(every_friend)
@app.get("/users")
async def search_users(name: str, current_user: User = Depends(get_current_user)):
    """Return the users whose username contains ``name``, ignoring case.

    Args:
        name: Text to look for; it is matched literally, not as a regular
            expression.

    Returns:
        The matching user documents, passed through users_serialize.
    """
    # Escape the input so regex metacharacters typed by the client are matched
    # as plain text: an invalid pattern can no longer make the query fail, and
    # a crafted one can no longer be injected into it.
    literal_pattern = re.escape(name)
    matches = users_collection.find({"username": {"$regex": literal_pattern, "$options": "i"}})
    return users_serialize(matches.to_list())