You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

66 lines
2.1 KiB

from datetime import datetime, timedelta, timezone

import pymongo
from fastapi import Depends, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.security import OAuth2PasswordBearer
from joserfc import jwt
from joserfc.errors import JoseError
from joserfc.jwk import OctKey

from app.models import User
from app.models.token_data import TokenData
from app.serializers import user_serialize
import app.config as config
# Database setup (executed once, when this module is imported).
client = pymongo.MongoClient(
    config.MONGODB_URL,
    username=config.MONGODB_USERNAME,
    password=config.MONGODB_PASSWORD,
)
db = client.get_database(config.MONGODB_DATABASE)
users_collection = db.get_collection("users")

# Bearer-token extractor; it is the default dependency of get_current_user.
# config.TOKEN_URL is the token endpoint advertised to OAuth2 clients.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=config.TOKEN_URL)
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
    """Resolve the request's bearer token to the stored user.

    The token is verified with ``config.SECRET_KEY`` and must be signed with
    ``config.ALGORITHM`` (the algorithm ``create_access_token`` uses). Its
    ``sub`` claim is taken as the username and its ``exp`` claim must lie in
    the future. The matching document from ``users_collection`` is passed
    through ``user_serialize``.

    Args:
        token: Raw bearer token extracted by ``oauth2_scheme``.

    Returns:
        Whatever ``user_serialize`` builds from the user document.

    Raises:
        HTTPException: 401 when the token cannot be decoded or verified, when
            it lacks a usable ``sub``/``exp`` claim, when it has expired, or
            when no user with that username exists.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Keep the try minimal: only decoding/verification can raise JoseError.
    # ValueError covers low-level decoding failures (e.g. bad base64).
    try:
        decoded = jwt.decode(
            token,
            OctKey.import_key(config.SECRET_KEY),
            # Pin the algorithm so only tokens signed the way we sign them pass.
            algorithms=[config.ALGORITHM],
        )
    except (JoseError, ValueError) as exc:
        raise credentials_exception from exc

    claims = decoded.claims
    # .get() so a token missing a claim yields a 401 instead of a KeyError (500).
    username = claims.get("sub")
    expire_date = claims.get("exp")
    if not isinstance(username, str) or not username:
        raise credentials_exception
    # bool is a subclass of int; reject it so `"exp": true` is not accepted.
    if isinstance(expire_date, bool) or not isinstance(expire_date, (int, float)):
        raise credentials_exception
    # RFC 7519: the token must not be accepted on or after its exp time.
    if datetime.now(timezone.utc).timestamp() >= expire_date:
        raise credentials_exception
    token_data = TokenData(username=username)

    # pymongo calls block; run the lookup in a worker thread so the event
    # loop is not stalled while waiting on the database.
    user = await run_in_threadpool(
        users_collection.find_one, {"username": token_data.username}
    )
    if user is None:
        raise credentials_exception
    return user_serialize(user)
def create_access_token(data: dict, expires_delta: timedelta):
    """Sign ``data`` as a JWT that expires ``expires_delta`` from now.

    Args:
        data: Claims to embed in the token. The dict is shallow-copied, so the
            caller's object is not modified. Any ``exp`` key it contains is
            overwritten.
        expires_delta: How long the token stays valid.

    Returns:
        The token produced by ``jwt.encode``, signed with ``config.SECRET_KEY``
        using ``config.ALGORITHM``.
    """
    to_encode = data.copy()
    # Emit exp as an explicit integer UTC epoch (RFC 7519 NumericDate).
    # Passing a naive local datetime lets the encoder decide how to interpret
    # it. If it reads the wall-clock time as UTC, expiry shifts by the
    # server's UTC offset and disagrees with the epoch check done in
    # get_current_user.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode["exp"] = int(expire.timestamp())
    header = {"alg": config.ALGORITHM}
    return jwt.encode(header, to_encode, OctKey.import_key(config.SECRET_KEY))
# Exceptions
def friend_not_found():
    """Abort the current request with an HTTP 404 for a missing friend.

    This helper never returns normally.

    Raises:
        HTTPException: status 404 with detail "Friend not found".
    """
    not_found_error = HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Friend not found",
    )
    raise not_found_error
def objectid_misformatted():
    """Abort the current request with an HTTP 422 for a malformed ObjectID.

    This helper never returns normally.

    Raises:
        HTTPException: status 422 with detail "The ObjectID is misformatted".
    """
    unprocessable_error = HTTPException(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        detail="The ObjectID is misformatted",
    )
    raise unprocessable_error