diff --git a/yt_dlp/extractor/sheeta.py b/yt_dlp/extractor/sheeta.py index 0cac2f5449..ce0e08c4af 100644 --- a/yt_dlp/extractor/sheeta.py +++ b/yt_dlp/extractor/sheeta.py @@ -1,4 +1,5 @@ import base64 +import enum import functools import hashlib import json @@ -25,6 +26,11 @@ from ..utils.traversal import traverse_obj +class AuthType(enum.Enum): + AUTH0 = 'auth0' + NICONICO = 'niconico' + + class AuthManager: _AUTH_INFO_CACHE = {} _AUTH0_BASE64_TRANS = str.maketrans({ @@ -79,11 +85,17 @@ def _get_auth_token(self): return None def _refresh_token(self): - if not (refresh_func := self._auth_info.get('refresh_func')): + if not (refresh_func_params := self._auth_info.get('refresh_func_params')): return False + if self._auth_info.get('auth_type') == AuthType.AUTH0: + refresh_func_params['data'] = urlencode_postdata({ + **refresh_func_params['data'], + 'refresh_token': self._auth_info.get('refresh_token'), + }) + res = self._ie._download_json( - **refresh_func(self._auth_info), expected_status=(400, 403, 404), + **refresh_func_params, expected_status=(400, 403, 404), note='Refreshing token', errnote='Unable to refresh token') if error := traverse_obj( res, ('error', 'message', {lambda x: base64.b64decode(x).decode()}), ('error', 'message')): @@ -127,7 +139,7 @@ def _login(self): return self._auth0_login() def _niconico_sns_login(self, redirect_url, refresh_url): - self._auth_info = {'login_method': 'any'} + self._auth_info = {'login_method': 'any', 'auth_type': AuthType.NICONICO} mail_tel, password = self._ie._get_login_info() if not mail_tel: return @@ -153,13 +165,12 @@ def _niconico_sns_login(self, redirect_url, refresh_url): return self._auth_info = { - 'refresh_func': lambda data: { - 'url_or_request': data['refresh_url'], + 'refresh_func_params': { + 'url_or_request': refresh_url, 'video_id': None, - 'headers': {'Authorization': data['auth_token']}, + 'headers': {'Authorization': auth_token}, 'data': b'', }, - 'refresh_url': refresh_url, 'auth_token': auth_token, } @@ -235,7 +246,7 @@ def _niconico_login(self, mail_tel, password): return True def _auth0_login(self): - self._auth_info = {'login_method': 'password'} + self._auth_info = {'login_method': 'password', 'auth_type': AuthType.AUTH0} username, password = self._ie._get_login_info() if not username: return @@ -260,17 +271,15 @@ def _auth0_login(self): 'version': '2.0.6', }).encode()).decode() - self._auth_info = {'refresh_func': lambda data: { + self._auth_info = {'refresh_func_params': { 'url_or_request': token_url, 'video_id': None, 'headers': {'Auth0-Client': auth0_client}, - 'data': urlencode_postdata({ + 'data': { 'client_id': auth0_web_client_id, 'grant_type': 'refresh_token', - 'refresh_token': data['refresh_token'], 'redirect_uri': redirect_url, - }), - }} + }}} def random_str(): return ''.join(random.choices(string.digits + string.ascii_letters, k=43))