feat: Use aiohttp instead of waiting for blocking calls (#227)

* Use native async call instead of converting blocking calls

* remove nullable declarations

* fixs

* Fix star expression

* fix gather again

* remove unused private function

* Fix naming conflict

* Add the deleted function back. Disable the warning instead.

* remove trailing space

* handle wrong mime type from cloud

* Fix request header

* fix missing await
This commit is contained in:
Feng Wang 2024-12-20 17:34:34 +08:00 committed by GitHub
parent a879ae2cdf
commit aacb794e1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 77 additions and 111 deletions

View File

@ -23,7 +23,8 @@
"paho-mqtt<=2.0.0",
"numpy",
"cryptography",
"psutil"
"psutil",
"aiohttp[speedups]"
],
"version": "v0.1.2",
"zeroconf": [

View File

@ -51,10 +51,9 @@ import json
import logging
import re
import time
from functools import partial
from typing import Optional
from urllib.parse import urlencode
import requests
import aiohttp
# pylint: disable=relative-beyond-top-level
from .common import calc_group_id
@ -71,8 +70,9 @@ TOKEN_EXPIRES_TS_RATIO = 0.7
class MIoTOauthClient:
"""oauth agent url, default: product env."""
_main_loop: asyncio.AbstractEventLoop = None
_oauth_host: str = None
_main_loop: asyncio.AbstractEventLoop
_session: aiohttp.ClientSession
_oauth_host: str
_client_id: int
_redirect_url: str
@ -94,9 +94,10 @@ class MIoTOauthClient:
self._oauth_host = DEFAULT_OAUTH2_API_HOST
else:
self._oauth_host = f'{cloud_server}.{DEFAULT_OAUTH2_API_HOST}'
self._session = aiohttp.ClientSession()
async def __call_async(self, func):
return await self._main_loop.run_in_executor(executor=None, func=func)
def __del__(self):
self._session.close()
def set_redirect_url(self, redirect_url: str) -> None:
if not isinstance(redirect_url, str) or redirect_url.strip() == '':
@ -140,21 +141,22 @@ class MIoTOauthClient:
return f'{OAUTH2_AUTH_URL}?{encoded_params}'
def _get_token(self, data) -> dict:
http_res = requests.get(
async def __get_token_async(self, data) -> dict:
http_res = await self._session.get(
url=f'https://{self._oauth_host}/app/v2/ha/oauth/get_token',
params={'data': json.dumps(data)},
headers={'content-type': 'application/x-www-form-urlencoded'},
timeout=MIHOME_HTTP_API_TIMEOUT
)
if http_res.status_code == 401:
if http_res.status == 401:
raise MIoTOauthError(
'unauthorized(401)', MIoTErrorCode.CODE_OAUTH_UNAUTHORIZED)
if http_res.status_code != 200:
if http_res.status != 200:
raise MIoTOauthError(
f'invalid http status code, {http_res.status_code}')
f'invalid http status code, {http_res.status}')
res_obj = http_res.json()
res_str = await http_res.text()
res_obj = json.loads(res_str)
if (
not res_obj
or res_obj.get('code', None) != 0
@ -172,7 +174,7 @@ class MIoTOauthClient:
(res_obj['result'].get('expires_in', 0)*TOKEN_EXPIRES_TS_RATIO))
}
def get_access_token(self, code: str) -> dict:
async def get_access_token_async(self, code: str) -> dict:
"""get access token by authorization code
Args:
@ -184,16 +186,13 @@ class MIoTOauthClient:
if not isinstance(code, str):
raise MIoTOauthError('invalid code')
return self._get_token(data={
return await self.__get_token_async(data={
'client_id': self._client_id,
'redirect_uri': self._redirect_url,
'code': code,
})
async def get_access_token_async(self, code: str) -> dict:
return await self.__call_async(partial(self.get_access_token, code))
def refresh_access_token(self, refresh_token: str) -> dict:
async def refresh_access_token_async(self, refresh_token: str) -> dict:
"""get access token by refresh token.
Args:
@ -205,16 +204,12 @@ class MIoTOauthClient:
if not isinstance(refresh_token, str):
raise MIoTOauthError('invalid refresh_token')
return self._get_token(data={
return await self._get_token_async(data={
'client_id': self._client_id,
'redirect_uri': self._redirect_url,
'refresh_token': refresh_token,
})
async def refresh_access_token_async(self, refresh_token: str) -> dict:
return await self.__call_async(
partial(self.refresh_access_token, refresh_token))
class MIoTHttpClient:
"""MIoT http client."""
@ -222,6 +217,7 @@ class MIoTHttpClient:
GET_PROP_AGGREGATE_INTERVAL: float = 0.2
GET_PROP_MAX_REQ_COUNT = 150
_main_loop: asyncio.AbstractEventLoop
_session: aiohttp.ClientSession
_host: str
_base_url: str
_client_id: str
@ -254,10 +250,10 @@ class MIoTHttpClient:
cloud_server=cloud_server, client_id=client_id,
access_token=access_token)
async def __call_async(self, func) -> any:
if self._main_loop is None:
raise MIoTHttpError('miot http, un-support async methods')
return await self._main_loop.run_in_executor(executor=None, func=func)
self._session = aiohttp.ClientSession()
def __del__(self):
self._session.close()
def update_http_header(
self, cloud_server: Optional[str] = None,
@ -276,36 +272,35 @@ class MIoTHttpClient:
self._access_token = access_token
@property
def __api_session(self) -> requests.Session:
session = requests.Session()
session.headers.update({
def __api_request_headers(self) -> dict:
return {
'Host': self._host,
'X-Client-BizId': 'haapi',
'Content-Type': 'application/json',
'Authorization': f'Bearer{self._access_token}',
'X-Client-AppId': self._client_id,
})
return session
}
def mihome_api_get(
# pylint: disable=unused-private-member
async def __mihome_api_get_async(
self, url_path: str, params: dict,
timeout: int = MIHOME_HTTP_API_TIMEOUT
) -> dict:
http_res = None
with self.__api_session as session:
http_res = session.get(
url=f'{self._base_url}{url_path}',
params=params,
timeout=timeout)
if http_res.status_code == 401:
http_res = await self._session.get(
url=f'{self._base_url}{url_path}',
params=params,
headers=self.__api_request_headers,
timeout=timeout)
if http_res.status == 401:
raise MIoTHttpError(
'mihome api get failed, unauthorized(401)',
MIoTErrorCode.CODE_HTTP_INVALID_ACCESS_TOKEN)
if http_res.status_code != 200:
if http_res.status != 200:
raise MIoTHttpError(
f'mihome api get failed, {http_res.status_code}, '
f'mihome api get failed, {http_res.status}, '
f'{url_path}, {params}')
res_obj: dict = http_res.json()
res_str = await http_res.text()
res_obj: dict = json.loads(res_str)
if res_obj.get('code', None) != 0:
raise MIoTHttpError(
f'invalid response code, {res_obj.get("code",None)}, '
@ -315,28 +310,25 @@ class MIoTHttpClient:
self._base_url, url_path, params, res_obj)
return res_obj
def mihome_api_post(
async def __mihome_api_post_async(
self, url_path: str, data: dict,
timeout: int = MIHOME_HTTP_API_TIMEOUT
) -> dict:
encoded_data = None
if data:
encoded_data = json.dumps(data).encode('utf-8')
http_res = None
with self.__api_session as session:
http_res = session.post(
url=f'{self._base_url}{url_path}',
data=encoded_data,
timeout=timeout)
if http_res.status_code == 401:
http_res = await self._session.post(
url=f'{self._base_url}{url_path}',
json=data,
headers=self.__api_request_headers,
timeout=timeout)
if http_res.status == 401:
raise MIoTHttpError(
'mihome api get failed, unauthorized(401)',
MIoTErrorCode.CODE_HTTP_INVALID_ACCESS_TOKEN)
if http_res.status_code != 200:
if http_res.status != 200:
raise MIoTHttpError(
f'mihome api post failed, {http_res.status_code}, '
f'mihome api post failed, {http_res.status}, '
f'{url_path}, {data}')
res_obj: dict = http_res.json()
res_str = await http_res.text()
res_obj: dict = json.loads(res_str)
if res_obj.get('code', None) != 0:
raise MIoTHttpError(
f'invalid response code, {res_obj.get("code",None)}, '
@ -346,8 +338,8 @@ class MIoTHttpClient:
self._base_url, url_path, data, res_obj)
return res_obj
def get_user_info(self) -> dict:
http_res = requests.get(
async def get_user_info_async(self) -> dict:
http_res = await self._session.get(
url='https://open.account.xiaomi.com/user/profile',
params={'clientId': self._client_id,
'token': self._access_token},
@ -355,7 +347,8 @@ class MIoTHttpClient:
timeout=MIHOME_HTTP_API_TIMEOUT
)
res_obj = http_res.json()
res_str = await http_res.text()
res_obj = json.loads(res_str)
if (
not res_obj
or res_obj.get('code', None) != 0
@ -366,14 +359,11 @@ class MIoTHttpClient:
return res_obj['data']
async def get_user_info_async(self) -> dict:
return await self.__call_async(partial(self.get_user_info))
def get_central_cert(self, csr: str) -> Optional[str]:
async def get_central_cert_async(self, csr: str) -> Optional[str]:
if not isinstance(csr, str):
raise MIoTHttpError('invalid params')
res_obj: dict = self.mihome_api_post(
res_obj: dict = await self.__mihome_api_post_async(
url_path='/app/v2/ha/oauth/get_central_crt',
data={
'csr': str(base64.b64encode(csr.encode('utf-8')), 'utf-8')
@ -387,11 +377,8 @@ class MIoTHttpClient:
return cert
async def get_central_cert_async(self, csr: str) -> Optional[str]:
return await self.__call_async(partial(self.get_central_cert, csr))
def __get_dev_room_page(self, max_id: str = None) -> dict:
res_obj = self.mihome_api_post(
async def __get_dev_room_page_async(self, max_id: str = None) -> dict:
res_obj = await self.__mihome_api_post_async(
url_path='/app/v2/homeroom/get_dev_room_page',
data={
'start_id': max_id,
@ -419,7 +406,7 @@ class MIoTHttpClient:
res_obj['result'].get('has_more', False)
and isinstance(res_obj['result'].get('max_id', None), str)
):
next_list = self.__get_dev_room_page(
next_list = await self.__get_dev_room_page_async(
max_id=res_obj['result']['max_id'])
for home_id, info in next_list.items():
home_list.setdefault(home_id, {'dids': [], 'room_info': {}})
@ -432,8 +419,8 @@ class MIoTHttpClient:
return home_list
def get_homeinfos(self) -> dict:
res_obj = self.mihome_api_post(
async def get_homeinfos_async(self) -> dict:
res_obj = await self.__mihome_api_post_async(
url_path='/app/v2/homeroom/gethome',
data={
'limit': 150,
@ -485,7 +472,7 @@ class MIoTHttpClient:
res_obj['result'].get('has_more', False)
and isinstance(res_obj['result'].get('max_id', None), str)
):
more_list = self.__get_dev_room_page(
more_list = await self.__get_dev_room_page_async(
max_id=res_obj['result']['max_id'])
for home_id, info in more_list.items():
if home_id not in home_infos['homelist']:
@ -507,16 +494,10 @@ class MIoTHttpClient:
'share_home_list': home_infos.get('share_home_list', [])
}
async def get_homeinfos_async(self) -> dict:
return await self.__call_async(self.get_homeinfos)
def get_uid(self) -> str:
return self.get_homeinfos().get('uid', None)
async def get_uid_async(self) -> str:
return (await self.get_homeinfos_async()).get('uid', None)
def __get_device_list_page(
async def __get_device_list_page_async(
self, dids: list[str], start_did: str = None
) -> dict[str, dict]:
req_data: dict = {
@ -527,7 +508,7 @@ class MIoTHttpClient:
if start_did:
req_data['start_did'] = start_did
device_infos: dict = {}
res_obj = self.mihome_api_post(
res_obj = await self.__mihome_api_post_async(
url_path='/app/v2/home/device_list_page',
data=req_data
)
@ -578,7 +559,7 @@ class MIoTHttpClient:
next_start_did = res_obj.get('next_start_did', None)
if res_obj.get('has_more', False) and next_start_did:
device_infos.update(self.__get_device_list_page(
device_infos.update(await self.__get_device_list_page_async(
dids=dids, start_did=next_start_did))
return device_infos
@ -587,8 +568,7 @@ class MIoTHttpClient:
self, dids: list[str]
) -> dict[str, dict]:
results: list[dict[str, dict]] = await asyncio.gather(
*[self.__call_async(
partial(self.__get_device_list_page, dids[index:index+150]))
*[self.__get_device_list_page_async(dids[index:index+150])
for index in range(0, len(dids), 150)])
devices = {}
for result in results:
@ -665,12 +645,12 @@ class MIoTHttpClient:
'devices': devices
}
def get_props(self, params: list) -> list:
async def get_props_async(self, params: list) -> list:
"""
params = [{"did": "xxxx", "siid": 2, "piid": 1},
{"did": "xxxxxx", "siid": 2, "piid": 2}]
"""
res_obj = self.mihome_api_post(
res_obj = await self.__mihome_api_post_async(
url_path='/app/v2/miotspec/prop/get',
data={
'datasource': 1,
@ -681,11 +661,9 @@ class MIoTHttpClient:
raise MIoTHttpError('invalid response result')
return res_obj['result']
async def get_props_async(self, params: list) -> list:
return await self.__call_async(partial(self.get_props, params))
def get_prop(self, did: str, siid: int, piid: int) -> any:
results = self.get_props(
async def __get_prop_async(self, did: str, siid: int, piid: int) -> any:
results = await self.get_props_async(
params=[{'did': did, 'siid': siid, 'piid': piid}])
if not results:
return None
@ -711,7 +689,7 @@ class MIoTHttpClient:
if not props_buffer:
_LOGGER.error('get prop error, empty request list')
return False
results = await self.__call_async(partial(self.get_props, props_buffer))
results = await self.get_props_async(props_buffer)
for result in results:
if not all(
@ -747,8 +725,7 @@ class MIoTHttpClient:
self, did: str, siid: int, piid: int, immediately: bool = False
) -> any:
if immediately:
return await self.__call_async(
partial(self.get_prop, did, siid, piid))
return await self.__get_prop_async(did, siid, piid)
key: str = f'{did}.{siid}.{piid}'
prop_obj = self._get_prop_list.get(key, None)
if prop_obj:
@ -766,11 +743,11 @@ class MIoTHttpClient:
return await fut
def set_prop(self, params: list) -> list:
async def set_prop_async(self, params: list) -> list:
"""
params = [{"did": "xxxx", "siid": 2, "piid": 1, "value": False}]
"""
res_obj = self.mihome_api_post(
res_obj = await self.__mihome_api_post_async(
url_path='/app/v2/miotspec/prop/set',
data={
'params': params
@ -782,20 +759,14 @@ class MIoTHttpClient:
return res_obj['result']
async def set_prop_async(self, params: list) -> list:
"""
params = [{"did": "xxxx", "siid": 2, "piid": 1, "value": False}]
"""
return await self.__call_async(partial(self.set_prop, params))
def action(
async def action_async(
self, did: str, siid: int, aiid: int, in_list: list[dict]
) -> dict:
"""
params = {"did": "xxxx", "siid": 2, "aiid": 1, "in": []}
"""
# NOTICE: Non-standard action param
res_obj = self.mihome_api_post(
res_obj = await self.__mihome_api_post_async(
url_path='/app/v2/miotspec/action',
data={
'params': {
@ -810,9 +781,3 @@ class MIoTHttpClient:
raise MIoTHttpError('invalid response result')
return res_obj['result']
async def action_async(
self, did: str, siid: int, aiid: int, in_list: list[dict]
) -> dict:
return await self.__call_async(
partial(self.action, did, siid, aiid, in_list))