feat: Use aiohttp instead of waiting for blocking calls (#227)

* Use native async call instead of converting blocking calls

* remove nullable declarations

* fixs

* Fix star expression

* fix gather again

* remove unused private function

* Fix naming conflict

* Add the deleted function back. Disable the warning instead.

* remove trailing space

* handle wrong mime type from cloud

* Fix request header

* fix missing await
This commit is contained in:
Feng Wang 2024-12-20 17:34:34 +08:00 committed by GitHub
parent a879ae2cdf
commit aacb794e1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 77 additions and 111 deletions

View File

@ -23,7 +23,8 @@
"paho-mqtt<=2.0.0", "paho-mqtt<=2.0.0",
"numpy", "numpy",
"cryptography", "cryptography",
"psutil" "psutil",
"aiohttp[speedups]"
], ],
"version": "v0.1.2", "version": "v0.1.2",
"zeroconf": [ "zeroconf": [

View File

@ -51,10 +51,9 @@ import json
import logging import logging
import re import re
import time import time
from functools import partial
from typing import Optional from typing import Optional
from urllib.parse import urlencode from urllib.parse import urlencode
import requests import aiohttp
# pylint: disable=relative-beyond-top-level # pylint: disable=relative-beyond-top-level
from .common import calc_group_id from .common import calc_group_id
@ -71,8 +70,9 @@ TOKEN_EXPIRES_TS_RATIO = 0.7
class MIoTOauthClient: class MIoTOauthClient:
"""oauth agent url, default: product env.""" """oauth agent url, default: product env."""
_main_loop: asyncio.AbstractEventLoop = None _main_loop: asyncio.AbstractEventLoop
_oauth_host: str = None _session: aiohttp.ClientSession
_oauth_host: str
_client_id: int _client_id: int
_redirect_url: str _redirect_url: str
@ -94,9 +94,10 @@ class MIoTOauthClient:
self._oauth_host = DEFAULT_OAUTH2_API_HOST self._oauth_host = DEFAULT_OAUTH2_API_HOST
else: else:
self._oauth_host = f'{cloud_server}.{DEFAULT_OAUTH2_API_HOST}' self._oauth_host = f'{cloud_server}.{DEFAULT_OAUTH2_API_HOST}'
self._session = aiohttp.ClientSession()
async def __call_async(self, func): def __del__(self):
return await self._main_loop.run_in_executor(executor=None, func=func) self._session.close()
def set_redirect_url(self, redirect_url: str) -> None: def set_redirect_url(self, redirect_url: str) -> None:
if not isinstance(redirect_url, str) or redirect_url.strip() == '': if not isinstance(redirect_url, str) or redirect_url.strip() == '':
@ -140,21 +141,22 @@ class MIoTOauthClient:
return f'{OAUTH2_AUTH_URL}?{encoded_params}' return f'{OAUTH2_AUTH_URL}?{encoded_params}'
def _get_token(self, data) -> dict: async def __get_token_async(self, data) -> dict:
http_res = requests.get( http_res = await self._session.get(
url=f'https://{self._oauth_host}/app/v2/ha/oauth/get_token', url=f'https://{self._oauth_host}/app/v2/ha/oauth/get_token',
params={'data': json.dumps(data)}, params={'data': json.dumps(data)},
headers={'content-type': 'application/x-www-form-urlencoded'}, headers={'content-type': 'application/x-www-form-urlencoded'},
timeout=MIHOME_HTTP_API_TIMEOUT timeout=MIHOME_HTTP_API_TIMEOUT
) )
if http_res.status_code == 401: if http_res.status == 401:
raise MIoTOauthError( raise MIoTOauthError(
'unauthorized(401)', MIoTErrorCode.CODE_OAUTH_UNAUTHORIZED) 'unauthorized(401)', MIoTErrorCode.CODE_OAUTH_UNAUTHORIZED)
if http_res.status_code != 200: if http_res.status != 200:
raise MIoTOauthError( raise MIoTOauthError(
f'invalid http status code, {http_res.status_code}') f'invalid http status code, {http_res.status}')
res_obj = http_res.json() res_str = await http_res.text()
res_obj = json.loads(res_str)
if ( if (
not res_obj not res_obj
or res_obj.get('code', None) != 0 or res_obj.get('code', None) != 0
@ -172,7 +174,7 @@ class MIoTOauthClient:
(res_obj['result'].get('expires_in', 0)*TOKEN_EXPIRES_TS_RATIO)) (res_obj['result'].get('expires_in', 0)*TOKEN_EXPIRES_TS_RATIO))
} }
def get_access_token(self, code: str) -> dict: async def get_access_token_async(self, code: str) -> dict:
"""get access token by authorization code """get access token by authorization code
Args: Args:
@ -184,16 +186,13 @@ class MIoTOauthClient:
if not isinstance(code, str): if not isinstance(code, str):
raise MIoTOauthError('invalid code') raise MIoTOauthError('invalid code')
return self._get_token(data={ return await self.__get_token_async(data={
'client_id': self._client_id, 'client_id': self._client_id,
'redirect_uri': self._redirect_url, 'redirect_uri': self._redirect_url,
'code': code, 'code': code,
}) })
async def get_access_token_async(self, code: str) -> dict: async def refresh_access_token_async(self, refresh_token: str) -> dict:
return await self.__call_async(partial(self.get_access_token, code))
def refresh_access_token(self, refresh_token: str) -> dict:
"""get access token by refresh token. """get access token by refresh token.
Args: Args:
@ -205,16 +204,12 @@ class MIoTOauthClient:
if not isinstance(refresh_token, str): if not isinstance(refresh_token, str):
raise MIoTOauthError('invalid refresh_token') raise MIoTOauthError('invalid refresh_token')
return self._get_token(data={ return await self._get_token_async(data={
'client_id': self._client_id, 'client_id': self._client_id,
'redirect_uri': self._redirect_url, 'redirect_uri': self._redirect_url,
'refresh_token': refresh_token, 'refresh_token': refresh_token,
}) })
async def refresh_access_token_async(self, refresh_token: str) -> dict:
return await self.__call_async(
partial(self.refresh_access_token, refresh_token))
class MIoTHttpClient: class MIoTHttpClient:
"""MIoT http client.""" """MIoT http client."""
@ -222,6 +217,7 @@ class MIoTHttpClient:
GET_PROP_AGGREGATE_INTERVAL: float = 0.2 GET_PROP_AGGREGATE_INTERVAL: float = 0.2
GET_PROP_MAX_REQ_COUNT = 150 GET_PROP_MAX_REQ_COUNT = 150
_main_loop: asyncio.AbstractEventLoop _main_loop: asyncio.AbstractEventLoop
_session: aiohttp.ClientSession
_host: str _host: str
_base_url: str _base_url: str
_client_id: str _client_id: str
@ -254,10 +250,10 @@ class MIoTHttpClient:
cloud_server=cloud_server, client_id=client_id, cloud_server=cloud_server, client_id=client_id,
access_token=access_token) access_token=access_token)
async def __call_async(self, func) -> any: self._session = aiohttp.ClientSession()
if self._main_loop is None:
raise MIoTHttpError('miot http, un-support async methods') def __del__(self):
return await self._main_loop.run_in_executor(executor=None, func=func) self._session.close()
def update_http_header( def update_http_header(
self, cloud_server: Optional[str] = None, self, cloud_server: Optional[str] = None,
@ -276,36 +272,35 @@ class MIoTHttpClient:
self._access_token = access_token self._access_token = access_token
@property @property
def __api_session(self) -> requests.Session: def __api_request_headers(self) -> dict:
session = requests.Session() return {
session.headers.update({
'Host': self._host, 'Host': self._host,
'X-Client-BizId': 'haapi', 'X-Client-BizId': 'haapi',
'Content-Type': 'application/json', 'Content-Type': 'application/json',
'Authorization': f'Bearer{self._access_token}', 'Authorization': f'Bearer{self._access_token}',
'X-Client-AppId': self._client_id, 'X-Client-AppId': self._client_id,
}) }
return session
def mihome_api_get( # pylint: disable=unused-private-member
async def __mihome_api_get_async(
self, url_path: str, params: dict, self, url_path: str, params: dict,
timeout: int = MIHOME_HTTP_API_TIMEOUT timeout: int = MIHOME_HTTP_API_TIMEOUT
) -> dict: ) -> dict:
http_res = None http_res = await self._session.get(
with self.__api_session as session:
http_res = session.get(
url=f'{self._base_url}{url_path}', url=f'{self._base_url}{url_path}',
params=params, params=params,
headers=self.__api_request_headers,
timeout=timeout) timeout=timeout)
if http_res.status_code == 401: if http_res.status == 401:
raise MIoTHttpError( raise MIoTHttpError(
'mihome api get failed, unauthorized(401)', 'mihome api get failed, unauthorized(401)',
MIoTErrorCode.CODE_HTTP_INVALID_ACCESS_TOKEN) MIoTErrorCode.CODE_HTTP_INVALID_ACCESS_TOKEN)
if http_res.status_code != 200: if http_res.status != 200:
raise MIoTHttpError( raise MIoTHttpError(
f'mihome api get failed, {http_res.status_code}, ' f'mihome api get failed, {http_res.status}, '
f'{url_path}, {params}') f'{url_path}, {params}')
res_obj: dict = http_res.json() res_str = await http_res.text()
res_obj: dict = json.loads(res_str)
if res_obj.get('code', None) != 0: if res_obj.get('code', None) != 0:
raise MIoTHttpError( raise MIoTHttpError(
f'invalid response code, {res_obj.get("code",None)}, ' f'invalid response code, {res_obj.get("code",None)}, '
@ -315,28 +310,25 @@ class MIoTHttpClient:
self._base_url, url_path, params, res_obj) self._base_url, url_path, params, res_obj)
return res_obj return res_obj
def mihome_api_post( async def __mihome_api_post_async(
self, url_path: str, data: dict, self, url_path: str, data: dict,
timeout: int = MIHOME_HTTP_API_TIMEOUT timeout: int = MIHOME_HTTP_API_TIMEOUT
) -> dict: ) -> dict:
encoded_data = None http_res = await self._session.post(
if data:
encoded_data = json.dumps(data).encode('utf-8')
http_res = None
with self.__api_session as session:
http_res = session.post(
url=f'{self._base_url}{url_path}', url=f'{self._base_url}{url_path}',
data=encoded_data, json=data,
headers=self.__api_request_headers,
timeout=timeout) timeout=timeout)
if http_res.status_code == 401: if http_res.status == 401:
raise MIoTHttpError( raise MIoTHttpError(
'mihome api get failed, unauthorized(401)', 'mihome api get failed, unauthorized(401)',
MIoTErrorCode.CODE_HTTP_INVALID_ACCESS_TOKEN) MIoTErrorCode.CODE_HTTP_INVALID_ACCESS_TOKEN)
if http_res.status_code != 200: if http_res.status != 200:
raise MIoTHttpError( raise MIoTHttpError(
f'mihome api post failed, {http_res.status_code}, ' f'mihome api post failed, {http_res.status}, '
f'{url_path}, {data}') f'{url_path}, {data}')
res_obj: dict = http_res.json() res_str = await http_res.text()
res_obj: dict = json.loads(res_str)
if res_obj.get('code', None) != 0: if res_obj.get('code', None) != 0:
raise MIoTHttpError( raise MIoTHttpError(
f'invalid response code, {res_obj.get("code",None)}, ' f'invalid response code, {res_obj.get("code",None)}, '
@ -346,8 +338,8 @@ class MIoTHttpClient:
self._base_url, url_path, data, res_obj) self._base_url, url_path, data, res_obj)
return res_obj return res_obj
def get_user_info(self) -> dict: async def get_user_info_async(self) -> dict:
http_res = requests.get( http_res = await self._session.get(
url='https://open.account.xiaomi.com/user/profile', url='https://open.account.xiaomi.com/user/profile',
params={'clientId': self._client_id, params={'clientId': self._client_id,
'token': self._access_token}, 'token': self._access_token},
@ -355,7 +347,8 @@ class MIoTHttpClient:
timeout=MIHOME_HTTP_API_TIMEOUT timeout=MIHOME_HTTP_API_TIMEOUT
) )
res_obj = http_res.json() res_str = await http_res.text()
res_obj = json.loads(res_str)
if ( if (
not res_obj not res_obj
or res_obj.get('code', None) != 0 or res_obj.get('code', None) != 0
@ -366,14 +359,11 @@ class MIoTHttpClient:
return res_obj['data'] return res_obj['data']
async def get_user_info_async(self) -> dict: async def get_central_cert_async(self, csr: str) -> Optional[str]:
return await self.__call_async(partial(self.get_user_info))
def get_central_cert(self, csr: str) -> Optional[str]:
if not isinstance(csr, str): if not isinstance(csr, str):
raise MIoTHttpError('invalid params') raise MIoTHttpError('invalid params')
res_obj: dict = self.mihome_api_post( res_obj: dict = await self.__mihome_api_post_async(
url_path='/app/v2/ha/oauth/get_central_crt', url_path='/app/v2/ha/oauth/get_central_crt',
data={ data={
'csr': str(base64.b64encode(csr.encode('utf-8')), 'utf-8') 'csr': str(base64.b64encode(csr.encode('utf-8')), 'utf-8')
@ -387,11 +377,8 @@ class MIoTHttpClient:
return cert return cert
async def get_central_cert_async(self, csr: str) -> Optional[str]: async def __get_dev_room_page_async(self, max_id: str = None) -> dict:
return await self.__call_async(partial(self.get_central_cert, csr)) res_obj = await self.__mihome_api_post_async(
def __get_dev_room_page(self, max_id: str = None) -> dict:
res_obj = self.mihome_api_post(
url_path='/app/v2/homeroom/get_dev_room_page', url_path='/app/v2/homeroom/get_dev_room_page',
data={ data={
'start_id': max_id, 'start_id': max_id,
@ -419,7 +406,7 @@ class MIoTHttpClient:
res_obj['result'].get('has_more', False) res_obj['result'].get('has_more', False)
and isinstance(res_obj['result'].get('max_id', None), str) and isinstance(res_obj['result'].get('max_id', None), str)
): ):
next_list = self.__get_dev_room_page( next_list = await self.__get_dev_room_page_async(
max_id=res_obj['result']['max_id']) max_id=res_obj['result']['max_id'])
for home_id, info in next_list.items(): for home_id, info in next_list.items():
home_list.setdefault(home_id, {'dids': [], 'room_info': {}}) home_list.setdefault(home_id, {'dids': [], 'room_info': {}})
@ -432,8 +419,8 @@ class MIoTHttpClient:
return home_list return home_list
def get_homeinfos(self) -> dict: async def get_homeinfos_async(self) -> dict:
res_obj = self.mihome_api_post( res_obj = await self.__mihome_api_post_async(
url_path='/app/v2/homeroom/gethome', url_path='/app/v2/homeroom/gethome',
data={ data={
'limit': 150, 'limit': 150,
@ -485,7 +472,7 @@ class MIoTHttpClient:
res_obj['result'].get('has_more', False) res_obj['result'].get('has_more', False)
and isinstance(res_obj['result'].get('max_id', None), str) and isinstance(res_obj['result'].get('max_id', None), str)
): ):
more_list = self.__get_dev_room_page( more_list = await self.__get_dev_room_page_async(
max_id=res_obj['result']['max_id']) max_id=res_obj['result']['max_id'])
for home_id, info in more_list.items(): for home_id, info in more_list.items():
if home_id not in home_infos['homelist']: if home_id not in home_infos['homelist']:
@ -507,16 +494,10 @@ class MIoTHttpClient:
'share_home_list': home_infos.get('share_home_list', []) 'share_home_list': home_infos.get('share_home_list', [])
} }
async def get_homeinfos_async(self) -> dict:
return await self.__call_async(self.get_homeinfos)
def get_uid(self) -> str:
return self.get_homeinfos().get('uid', None)
async def get_uid_async(self) -> str: async def get_uid_async(self) -> str:
return (await self.get_homeinfos_async()).get('uid', None) return (await self.get_homeinfos_async()).get('uid', None)
def __get_device_list_page( async def __get_device_list_page_async(
self, dids: list[str], start_did: str = None self, dids: list[str], start_did: str = None
) -> dict[str, dict]: ) -> dict[str, dict]:
req_data: dict = { req_data: dict = {
@ -527,7 +508,7 @@ class MIoTHttpClient:
if start_did: if start_did:
req_data['start_did'] = start_did req_data['start_did'] = start_did
device_infos: dict = {} device_infos: dict = {}
res_obj = self.mihome_api_post( res_obj = await self.__mihome_api_post_async(
url_path='/app/v2/home/device_list_page', url_path='/app/v2/home/device_list_page',
data=req_data data=req_data
) )
@ -578,7 +559,7 @@ class MIoTHttpClient:
next_start_did = res_obj.get('next_start_did', None) next_start_did = res_obj.get('next_start_did', None)
if res_obj.get('has_more', False) and next_start_did: if res_obj.get('has_more', False) and next_start_did:
device_infos.update(self.__get_device_list_page( device_infos.update(await self.__get_device_list_page_async(
dids=dids, start_did=next_start_did)) dids=dids, start_did=next_start_did))
return device_infos return device_infos
@ -587,8 +568,7 @@ class MIoTHttpClient:
self, dids: list[str] self, dids: list[str]
) -> dict[str, dict]: ) -> dict[str, dict]:
results: list[dict[str, dict]] = await asyncio.gather( results: list[dict[str, dict]] = await asyncio.gather(
*[self.__call_async( *[self.__get_device_list_page_async(dids[index:index+150])
partial(self.__get_device_list_page, dids[index:index+150]))
for index in range(0, len(dids), 150)]) for index in range(0, len(dids), 150)])
devices = {} devices = {}
for result in results: for result in results:
@ -665,12 +645,12 @@ class MIoTHttpClient:
'devices': devices 'devices': devices
} }
def get_props(self, params: list) -> list: async def get_props_async(self, params: list) -> list:
""" """
params = [{"did": "xxxx", "siid": 2, "piid": 1}, params = [{"did": "xxxx", "siid": 2, "piid": 1},
{"did": "xxxxxx", "siid": 2, "piid": 2}] {"did": "xxxxxx", "siid": 2, "piid": 2}]
""" """
res_obj = self.mihome_api_post( res_obj = await self.__mihome_api_post_async(
url_path='/app/v2/miotspec/prop/get', url_path='/app/v2/miotspec/prop/get',
data={ data={
'datasource': 1, 'datasource': 1,
@ -681,11 +661,9 @@ class MIoTHttpClient:
raise MIoTHttpError('invalid response result') raise MIoTHttpError('invalid response result')
return res_obj['result'] return res_obj['result']
async def get_props_async(self, params: list) -> list:
return await self.__call_async(partial(self.get_props, params))
def get_prop(self, did: str, siid: int, piid: int) -> any: async def __get_prop_async(self, did: str, siid: int, piid: int) -> any:
results = self.get_props( results = await self.get_props_async(
params=[{'did': did, 'siid': siid, 'piid': piid}]) params=[{'did': did, 'siid': siid, 'piid': piid}])
if not results: if not results:
return None return None
@ -711,7 +689,7 @@ class MIoTHttpClient:
if not props_buffer: if not props_buffer:
_LOGGER.error('get prop error, empty request list') _LOGGER.error('get prop error, empty request list')
return False return False
results = await self.__call_async(partial(self.get_props, props_buffer)) results = await self.get_props_async(props_buffer)
for result in results: for result in results:
if not all( if not all(
@ -747,8 +725,7 @@ class MIoTHttpClient:
self, did: str, siid: int, piid: int, immediately: bool = False self, did: str, siid: int, piid: int, immediately: bool = False
) -> any: ) -> any:
if immediately: if immediately:
return await self.__call_async( return await self.__get_prop_async(did, siid, piid)
partial(self.get_prop, did, siid, piid))
key: str = f'{did}.{siid}.{piid}' key: str = f'{did}.{siid}.{piid}'
prop_obj = self._get_prop_list.get(key, None) prop_obj = self._get_prop_list.get(key, None)
if prop_obj: if prop_obj:
@ -766,11 +743,11 @@ class MIoTHttpClient:
return await fut return await fut
def set_prop(self, params: list) -> list: async def set_prop_async(self, params: list) -> list:
""" """
params = [{"did": "xxxx", "siid": 2, "piid": 1, "value": False}] params = [{"did": "xxxx", "siid": 2, "piid": 1, "value": False}]
""" """
res_obj = self.mihome_api_post( res_obj = await self.__mihome_api_post_async(
url_path='/app/v2/miotspec/prop/set', url_path='/app/v2/miotspec/prop/set',
data={ data={
'params': params 'params': params
@ -782,20 +759,14 @@ class MIoTHttpClient:
return res_obj['result'] return res_obj['result']
async def set_prop_async(self, params: list) -> list: async def action_async(
"""
params = [{"did": "xxxx", "siid": 2, "piid": 1, "value": False}]
"""
return await self.__call_async(partial(self.set_prop, params))
def action(
self, did: str, siid: int, aiid: int, in_list: list[dict] self, did: str, siid: int, aiid: int, in_list: list[dict]
) -> dict: ) -> dict:
""" """
params = {"did": "xxxx", "siid": 2, "aiid": 1, "in": []} params = {"did": "xxxx", "siid": 2, "aiid": 1, "in": []}
""" """
# NOTICE: Non-standard action param # NOTICE: Non-standard action param
res_obj = self.mihome_api_post( res_obj = await self.__mihome_api_post_async(
url_path='/app/v2/miotspec/action', url_path='/app/v2/miotspec/action',
data={ data={
'params': { 'params': {
@ -810,9 +781,3 @@ class MIoTHttpClient:
raise MIoTHttpError('invalid response result') raise MIoTHttpError('invalid response result')
return res_obj['result'] return res_obj['result']
async def action_async(
self, did: str, siid: int, aiid: int, in_list: list[dict]
) -> dict:
return await self.__call_async(
partial(self.action, did, siid, aiid, in_list))