Merge pull request #94 from paul-nameless/more-mypy-checks

More strict mypy checks
This commit is contained in:
Nameless 2020-06-30 15:09:43 +08:00 committed by GitHub
commit 60278f98ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 269 additions and 220 deletions

View file

@ -44,4 +44,4 @@ jobs:
- name: Check types with mypy
run: |
mypy tg --warn-redundant-casts --warn-unused-ignores --no-warn-no-return --warn-unreachable --strict-equality --ignore-missing-imports
mypy tg --warn-redundant-casts --warn-unused-ignores --no-warn-no-return --warn-unreachable --strict-equality --ignore-missing-imports --warn-unused-configs --disallow-untyped-calls --disallow-untyped-defs --disallow-incomplete-defs --check-untyped-defs --disallow-untyped-decorators

2
.gitignore vendored
View file

@ -5,3 +5,5 @@ __pycache__
dist
*.log*
Makefile
.idea/
*monkeytype.sqlite3

View file

@ -21,4 +21,4 @@ line-length = 79
[tool.isort]
line_length = 79
multi_line_output = 3
include_trailing_comma = true
include_trailing_comma = true

View file

@ -8,6 +8,7 @@ from queue import Queue
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, List, Optional
from telegram.utils import AsyncResult
from tg import config
from tg.models import Model
from tg.msg import MsgProxy
@ -30,26 +31,26 @@ log = logging.getLogger(__name__)
# cause blan areas on the msg display screen
MSGS_LEFT_SCROLL_THRESHOLD = 2
REPLY_MSG_PREFIX = "# >"
handler_type = Callable[[Any], Any]
HandlerType = Callable[[Any], Optional[str]]
chat_handler: Dict[str, handler_type] = {}
msg_handler: Dict[str, handler_type] = {}
chat_handler: Dict[str, HandlerType] = {}
msg_handler: Dict[str, HandlerType] = {}
def bind(
binding: Dict[str, handler_type],
binding: Dict[str, HandlerType],
keys: List[str],
repeat_factor: bool = False,
):
) -> Callable:
"""bind handlers to given keys"""
def decorator(fun):
def decorator(fun: Callable) -> HandlerType:
@wraps(fun)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> Any:
return fun(*args, **kwargs)
@wraps(fun)
def _no_repeat_factor(self, repeat_factor):
def _no_repeat_factor(self: "Controller", _: bool) -> Any:
return fun(self)
for key in keys:
@ -70,9 +71,8 @@ class Controller:
self.chat_size = 0.5
@bind(msg_handler, ["o"])
def open_url(self):
msg = self.model.current_msg
msg = MsgProxy(msg)
def open_url(self) -> None:
msg = MsgProxy(self.model.current_msg)
if not msg.is_text:
self.present_error("Does not contain urls")
return
@ -95,41 +95,42 @@ class Controller:
with suspend(self.view) as s:
s.run_with_input(config.URL_VIEW, "\n".join(urls))
def format_help(self, bindings):
@staticmethod
def format_help(bindings: Dict[str, HandlerType]) -> str:
return "\n".join(
f"{key}\t{fun.__name__}\t{fun.__doc__ or ''}"
for key, fun in sorted(bindings.items())
)
@bind(chat_handler, ["?"])
def show_chat_help(self):
def show_chat_help(self) -> None:
_help = self.format_help(chat_handler)
with suspend(self.view) as s:
s.run_with_input(config.HELP_CMD, _help)
@bind(msg_handler, ["?"])
def show_msg_help(self):
def show_msg_help(self) -> None:
_help = self.format_help(msg_handler)
with suspend(self.view) as s:
s.run_with_input(config.HELP_CMD, _help)
@bind(chat_handler, ["bp"])
@bind(msg_handler, ["bp"])
def breakpoint(self):
def breakpoint(self) -> None:
with suspend(self.view):
breakpoint()
@bind(chat_handler, ["q"])
@bind(msg_handler, ["q"])
def quit(self):
def quit(self) -> str:
return "QUIT"
@bind(msg_handler, ["h", "^D"])
def back(self):
def back(self) -> str:
return "BACK"
@bind(msg_handler, ["p"])
def forward_msgs(self):
def forward_msgs(self) -> None:
"""Paste yanked msgs"""
if not self.model.forward_msgs():
self.present_error("Can't forward msg(s)")
@ -137,7 +138,7 @@ class Controller:
self.present_info("Forwarded msg(s)")
@bind(msg_handler, ["y"])
def yank_msgs(self):
def yank_msgs(self) -> None:
"""Copy msgs to clipboard and internal buffer to forward"""
chat_id = self.model.chats.id_by_index(self.model.current_chat)
if not chat_id:
@ -152,7 +153,7 @@ class Controller:
self.present_info(f"Copied {len(msg_ids)} msg(s)")
@bind(msg_handler, [" "])
def toggle_select_msg(self):
def toggle_select_msg(self) -> None:
chat_id = self.model.chats.id_by_index(self.model.current_chat)
if not chat_id:
return
@ -166,7 +167,7 @@ class Controller:
self.render_msgs()
@bind(msg_handler, ["^G", "^["])
def discard_selected_msgs(self):
def discard_selected_msgs(self) -> None:
chat_id = self.model.chats.id_by_index(self.model.current_chat)
if not chat_id:
return
@ -175,34 +176,36 @@ class Controller:
self.present_info("Discarded selected messages")
@bind(msg_handler, ["G"])
def bottom_msg(self):
def bottom_msg(self) -> None:
if self.model.jump_bottom():
self.render_msgs()
@bind(msg_handler, ["j", "^B", "^N"], repeat_factor=True)
def next_msg(self, repeat_factor: int = 1):
def next_msg(self, repeat_factor: int = 1) -> None:
if self.model.next_msg(repeat_factor):
self.render_msgs()
@bind(msg_handler, ["J"])
def jump_10_msgs_down(self):
def jump_10_msgs_down(self) -> None:
self.next_msg(10)
@bind(msg_handler, ["k", "^C", "^P"], repeat_factor=True)
def prev_msg(self, repeat_factor: int = 1):
def prev_msg(self, repeat_factor: int = 1) -> None:
if self.model.prev_msg(repeat_factor):
self.render_msgs()
@bind(msg_handler, ["K"])
def jump_10_msgs_up(self):
def jump_10_msgs_up(self) -> None:
self.prev_msg(10)
@bind(msg_handler, ["r"])
def reply_message(self):
def reply_message(self) -> None:
if not self.can_send_msg():
self.present_info("Can't send msg in this chat")
return
chat_id = self.model.current_chat_id
if chat_id is None:
return
reply_to_msg = self.model.current_msg_id
if msg := self.view.status.get_input():
self.tg.reply_message(chat_id, reply_to_msg, msg)
@ -211,11 +214,13 @@ class Controller:
self.present_info("Message reply wasn't sent")
@bind(msg_handler, ["R"])
def reply_with_long_message(self):
def reply_with_long_message(self) -> None:
if not self.can_send_msg():
self.present_info("Can't send msg in this chat")
return
chat_id = self.model.current_chat_id
if chat_id is None:
return
reply_to_msg = self.model.current_msg_id
msg = MsgProxy(self.model.current_msg)
with NamedTemporaryFile("w+", suffix=".txt") as f, suspend(
@ -225,18 +230,18 @@ class Controller:
f.seek(0)
s.call(config.LONG_MSG_CMD.format(file_path=shlex.quote(f.name)))
with open(f.name) as f:
if msg := strip_replied_msg(f.read().strip()):
self.tg.reply_message(chat_id, reply_to_msg, msg)
if replied_msg := strip_replied_msg(f.read().strip()):
self.tg.reply_message(chat_id, reply_to_msg, replied_msg)
self.present_info("Message sent")
else:
self.present_info("Message wasn't sent")
@bind(msg_handler, ["a", "i"])
def write_short_msg(self):
if not self.can_send_msg():
def write_short_msg(self) -> None:
chat_id = self.model.chats.id_by_index(self.model.current_chat)
if not self.can_send_msg() or chat_id is None:
self.present_info("Can't send msg in this chat")
return
chat_id = self.model.chats.id_by_index(self.model.current_chat)
self.tg.send_chat_action(chat_id, ChatAction.chatActionTyping)
if msg := self.view.status.get_input():
self.model.send_message(text=msg)
@ -246,14 +251,14 @@ class Controller:
self.present_info("Message wasn't sent")
@bind(msg_handler, ["A", "I"])
def write_long_msg(self):
if not self.can_send_msg():
def write_long_msg(self) -> None:
chat_id = self.model.chats.id_by_index(self.model.current_chat)
if not self.can_send_msg() or chat_id is None:
self.present_info("Can't send msg in this chat")
return
with NamedTemporaryFile("r+", suffix=".txt") as f, suspend(
self.view
) as s:
chat_id = self.model.chats.id_by_index(self.model.current_chat)
self.tg.send_chat_action(chat_id, ChatAction.chatActionTyping)
s.call(config.LONG_MSG_CMD.format(file_path=shlex.quote(f.name)))
with open(f.name) as f:
@ -267,7 +272,7 @@ class Controller:
self.present_info("Message wasn't sent")
@bind(msg_handler, ["sv"])
def send_video(self):
def send_video(self) -> None:
file_path = self.view.status.get_input()
if not file_path or not os.path.isfile(file_path):
return
@ -279,7 +284,7 @@ class Controller:
self.tg.send_video(file_path, chat_id, width, height, duration)
@bind(msg_handler, ["dd"])
def delete_msgs(self):
def delete_msgs(self) -> None:
is_deleted = self.model.delete_msgs()
self.discard_selected_msgs()
if not is_deleted:
@ -288,26 +293,30 @@ class Controller:
self.present_info("Message deleted")
@bind(msg_handler, ["sd"])
def send_document(self):
def send_document(self) -> None:
self.send_file(self.tg.send_doc)
@bind(msg_handler, ["sp"])
def send_picture(self):
def send_picture(self) -> None:
self.send_file(self.tg.send_photo)
@bind(msg_handler, ["sa"])
def send_audio(self):
def send_audio(self) -> None:
self.send_file(self.tg.send_audio)
def send_file(self, send_file_fun, *args, **kwargs):
def send_file(
self, send_file_fun: Callable[[str, int], AsyncResult],
) -> None:
file_path = self.view.status.get_input()
if file_path and os.path.isfile(file_path):
chat_id = self.model.chats.id_by_index(self.model.current_chat)
send_file_fun(file_path, chat_id, *args, **kwargs)
self.present_info("File sent")
if chat_id := self.model.chats.id_by_index(
self.model.current_chat
):
send_file_fun(file_path, chat_id)
self.present_info("File sent")
@bind(msg_handler, ["v"])
def record_voice(self):
def record_voice(self) -> None:
file_path = f"/tmp/voice-{datetime.now()}.oga"
with suspend(self.view) as s:
s.call(
@ -335,7 +344,7 @@ class Controller:
self.present_info(f"Sent voice msg: {file_path}")
@bind(msg_handler, ["D"])
def download_current_file(self):
def download_current_file(self) -> None:
msg = MsgProxy(self.model.current_msg)
log.debug("Downloading msg: %s", msg.msg)
file_id = msg.file_id
@ -345,7 +354,7 @@ class Controller:
self.download(file_id, msg["chat_id"], msg["id"])
self.present_info("File started downloading")
def download(self, file_id: int, chat_id: int, msg_id: int):
def download(self, file_id: int, chat_id: int, msg_id: int) -> None:
log.info("Downloading file: file_id=%s", file_id)
self.model.downloads[file_id] = (chat_id, msg_id)
self.tg.download_file(file_id=file_id)
@ -356,7 +365,7 @@ class Controller:
return chat["permissions"]["can_send_messages"]
@bind(msg_handler, ["l", "^J"])
def open_current_msg(self):
def open_current_msg(self) -> None:
msg = MsgProxy(self.model.current_msg)
if msg.is_text:
with NamedTemporaryFile("w", suffix=".txt") as f:
@ -378,7 +387,7 @@ class Controller:
s.open_file(path)
@bind(msg_handler, ["e"])
def edit_msg(self):
def edit_msg(self) -> None:
msg = MsgProxy(self.model.current_msg)
log.info("Editing msg: %s", msg.msg)
if not self.model.is_me(msg.sender_id):
@ -400,7 +409,7 @@ class Controller:
self.present_info("Message edited")
@bind(chat_handler, ["l", "^J", "^E"])
def handle_msgs(self):
def handle_msgs(self) -> Optional[str]:
rc = self.handle(msg_handler, 0.2)
if rc == "QUIT":
return rc
@ -408,32 +417,32 @@ class Controller:
self.resize()
@bind(chat_handler, ["g"])
def top_chat(self):
def top_chat(self) -> None:
if self.model.first_chat():
self.render()
@bind(chat_handler, ["j", "^B", "^N"], repeat_factor=True)
@bind(msg_handler, ["]"])
def next_chat(self, repeat_factor: int = 1):
def next_chat(self, repeat_factor: int = 1) -> None:
if self.model.next_chat(repeat_factor):
self.render()
@bind(chat_handler, ["k", "^C", "^P"], repeat_factor=True)
@bind(msg_handler, ["["])
def prev_chat(self, repeat_factor: int = 1):
def prev_chat(self, repeat_factor: int = 1) -> None:
if self.model.prev_chat(repeat_factor):
self.render()
@bind(chat_handler, ["J"])
def jump_10_chats_down(self):
def jump_10_chats_down(self) -> None:
self.next_chat(10)
@bind(chat_handler, ["K"])
def jump_10_chats_up(self):
def jump_10_chats_up(self) -> None:
self.prev_chat(10)
@bind(chat_handler, ["u"])
def toggle_unread(self):
def toggle_unread(self) -> None:
chat = self.model.chats.chats[self.model.current_chat]
chat_id = chat["id"]
toggle = not chat["is_marked_as_unread"]
@ -441,7 +450,7 @@ class Controller:
self.render()
@bind(chat_handler, ["r"])
def read_msgs(self):
def read_msgs(self) -> None:
chat = self.model.chats.chats[self.model.current_chat]
chat_id = chat["id"]
msg_id = chat["last_message"]["id"]
@ -449,7 +458,7 @@ class Controller:
self.render()
@bind(chat_handler, ["m"])
def toggle_mute(self):
def toggle_mute(self) -> None:
# TODO: if it's msg to yourself, do not change its
# notification setting, because we can't by documentation,
# instead write about it in status
@ -467,7 +476,7 @@ class Controller:
self.render()
@bind(chat_handler, ["p"])
def toggle_pin(self):
def toggle_pin(self) -> None:
chat = self.model.chats.chats[self.model.current_chat]
chat_id = chat["id"]
toggle = not chat["is_pinned"]
@ -481,10 +490,10 @@ class Controller:
except Exception:
log.exception("Error happened in main loop")
def close(self):
def close(self) -> None:
self.is_running = False
def handle(self, handlers: Dict[str, handler_type], size: float):
def handle(self, handlers: Dict[str, HandlerType], size: float) -> str:
self.chat_size = size
self.resize()
@ -497,15 +506,15 @@ class Controller:
elif res == "BACK":
return res
def resize_handler(self, signum, frame):
def resize_handler(self, signum: int, frame: Any) -> None:
curses.endwin()
self.view.stdscr.refresh()
self.resize()
def resize(self):
def resize(self) -> None:
self.queue.put(self._resize)
def _resize(self):
def _resize(self) -> None:
rows, cols = self.view.stdscr.getmaxyx()
# If we didn't clear the screen before doing this,
# the original window contents would remain on the screen
@ -518,7 +527,7 @@ class Controller:
self.view.status.resize(rows, cols)
self.render()
def draw(self):
def draw(self) -> None:
while self.is_running:
try:
log.info("Queue size: %d", self.queue.qsize())
@ -527,16 +536,16 @@ class Controller:
except Exception:
log.exception("Error happened in draw loop")
def present_error(self, msg: str):
def present_error(self, msg: str) -> None:
return self.update_status("Error", msg)
def present_info(self, msg: str):
def present_info(self, msg: str) -> None:
return self.update_status("Info", msg)
def update_status(self, level: str, msg: str):
def update_status(self, level: str, msg: str) -> None:
self.queue.put(partial(self._update_status, level, msg))
def _update_status(self, level: str, msg: str):
def _update_status(self, level: str, msg: str) -> None:
self.view.status.draw(f"{level}: {msg}")
def render(self) -> None:
@ -577,10 +586,9 @@ class Controller:
current_msg_idx, msgs, MSGS_LEFT_SCROLL_THRESHOLD, chat
)
def notify_for_message(self, chat_id: int, msg: MsgProxy):
def notify_for_message(self, chat_id: int, msg: MsgProxy) -> None:
# do not notify, if muted
# TODO: optimize
chat = None
for chat in self.model.chats.chats:
if chat_id == chat["id"]:
break
@ -598,10 +606,10 @@ class Controller:
user = self.model.users.get_user(msg.sender_id)
name = f"{user['first_name']} {user['last_name']}"
text = msg.text_content if msg.is_text else msg.content_type
notify(text, title=name)
if text := msg.text_content if msg.is_text else msg.content_type:
notify(text, title=name)
def _refresh_current_chat(self, current_chat_id: Optional[int]):
def refresh_current_chat(self, current_chat_id: Optional[int]) -> None:
if current_chat_id is None:
return
# TODO: we can create <index> for chats, it's faster than sqlite anyway
@ -616,6 +624,8 @@ class Controller:
def insert_replied_msg(msg: MsgProxy) -> str:
text = msg.text_content if msg.is_text else msg.content_type
if not text:
return ""
return (
"\n".join([f"{REPLY_MSG_PREFIX} {line}" for line in text.split("\n")])
# adding line with whitespace so text editor could start editing from last line

View file

@ -1,9 +1,9 @@
import logging
import logging.handlers
import signal
import threading
from curses import window, wrapper # type: ignore
from functools import partial
from types import FrameType
from tg import config, update_handlers, utils
from tg.controllers import Controller
@ -17,7 +17,7 @@ log = logging.getLogger(__name__)
def run(tg: Tdlib, stdscr: window) -> None:
# handle ctrl+c, to avoid interrupting tg when subprocess is called
def interrupt_signal_handler(sig, frame):
def interrupt_signal_handler(sig: int, frame: FrameType) -> None:
# TODO: draw on status pane: to quite press <q>
log.info("Interrupt signal is handled and ignored on purpose.")
@ -36,14 +36,14 @@ def run(tg: Tdlib, stdscr: window) -> None:
for msg_type, handler in update_handlers.handlers.items():
tg.add_update_handler(msg_type, partial(handler, controller))
thread = threading.Thread(target=controller.run,)
thread = threading.Thread(target=controller.run)
thread.daemon = True
thread.start()
controller.draw()
def main():
def main() -> None:
tg = Tdlib(
api_id=config.API_ID,
api_hash=config.API_HASH,

View file

@ -1,7 +1,7 @@
import logging
import time
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from tg.msg import MsgProxy
from tg.tdlib import ChatAction, Tdlib, UserStatus
@ -11,7 +11,7 @@ log = logging.getLogger(__name__)
class Model:
def __init__(self, tg: Tdlib):
def __init__(self, tg: Tdlib) -> None:
self.tg = tg
self.chats = ChatModel(tg)
self.msgs = MsgModel(tg)
@ -21,13 +21,13 @@ class Model:
self.selected: Dict[int, List[int]] = defaultdict(list)
self.copied_msgs: Tuple[int, List[int]] = (0, [])
def get_me(self):
def get_me(self) -> Dict[str, Any]:
return self.users.get_me()
def is_me(self, user_id: int) -> bool:
return self.get_me()["id"] == user_id
def get_user(self, user_id):
def get_user(self, user_id: int) -> Dict:
return self.users.get_user(user_id)
@property
@ -68,9 +68,10 @@ class Model:
def current_msg_id(self) -> int:
return self.current_msg["id"]
def jump_bottom(self):
chat_id = self.chats.id_by_index(self.current_chat)
return self.msgs.jump_bottom(chat_id)
def jump_bottom(self) -> bool:
if chat_id := self.chats.id_by_index(self.current_chat):
return self.msgs.jump_bottom(chat_id)
return False
def next_chat(self, step: int = 1) -> bool:
new_idx = self.current_chat + step
@ -85,17 +86,17 @@ class Model:
self.current_chat = max(0, self.current_chat - step)
return True
def first_chat(self):
def first_chat(self) -> bool:
if self.current_chat != 0:
self.current_chat = 0
return True
return False
def view_current_msg(self):
chat_id = self.chats.id_by_index(self.current_chat)
def view_current_msg(self) -> None:
msg = MsgProxy(self.current_msg)
msg_id = msg["id"]
self.tg.view_messages(chat_id, [msg_id])
if chat_id := self.chats.id_by_index(self.current_chat):
self.tg.view_messages(chat_id, [msg_id])
def next_msg(self, step: int = 1) -> bool:
chat_id = self.chats.id_by_index(self.current_chat)
@ -120,7 +121,7 @@ class Model:
current_position: int = 0,
page_size: int = 10,
msgs_left_scroll_threshold: int = 10,
):
) -> List[Dict[str, Any]]:
chats_left = page_size - current_position
offset = max(msgs_left_scroll_threshold - chats_left, 0)
limit = offset + page_size
@ -181,7 +182,7 @@ class Model:
self.copied_msgs = (0, [])
return True
def copy_msgs_text(self):
def copy_msgs_text(self) -> bool:
"""Copies current msg text or path to file if it's file"""
buffer = []
@ -193,15 +194,16 @@ class Model:
if not _msg:
return False
msg = MsgProxy(_msg)
if msg.file_id:
if msg.file_id and msg.local_path:
buffer.append(msg.local_path)
elif msg.is_text:
buffer.append(msg.text_content)
copy_to_clipboard("\n".join(buffer))
return True
class ChatModel:
def __init__(self, tg: Tdlib):
def __init__(self, tg: Tdlib) -> None:
self.tg = tg
self.chats: List[Dict[str, Any]] = []
self.chat_ids: List[int] = []
@ -221,7 +223,7 @@ class ChatModel:
return self.chats[offset:limit]
def _load_next_chats(self):
def _load_next_chats(self) -> None:
"""
based on
https://github.com/tdlib/td/issues/56#issuecomment-364221408
@ -283,7 +285,7 @@ class ChatModel:
class MsgModel:
def __init__(self, tg: Tdlib):
def __init__(self, tg: Tdlib) -> None:
self.tg = tg
self.msgs: Dict[int, List[Dict]] = defaultdict(list)
self.current_msgs: Dict[int, int] = defaultdict(int)
@ -297,7 +299,7 @@ class MsgModel:
self.current_msgs[chat_id] = max(0, current_msg - step)
return True
def jump_bottom(self, chat_id: int):
def jump_bottom(self, chat_id: int) -> bool:
if self.current_msgs[chat_id] == 0:
return False
self.current_msgs[chat_id] = 0
@ -324,7 +326,7 @@ class MsgModel:
return result.update
return next(iter(m for m in self.msgs[chat_id] if m["id"] == msg_id))
def remove_message(self, chat_id: int, msg_id: int):
def remove_message(self, chat_id: int, msg_id: int) -> bool:
msg_set = self.msg_ids[chat_id]
if msg_id not in msg_set:
return False
@ -336,7 +338,7 @@ class MsgModel:
msg_set.remove(msg_id)
return True
def update_msg_content_opened(self, chat_id: int, msg_id: int):
def update_msg_content_opened(self, chat_id: int, msg_id: int) -> None:
for message in self.msgs[chat_id]:
if message["id"] != msg_id:
continue
@ -350,7 +352,9 @@ class MsgModel:
# https://core.telegram.org/tdlib/docs/classtd_1_1td__api_1_1update_message_content_opened.html
return
def update_msg(self, chat_id: int, msg_id: int, **fields: Dict[str, Any]):
def update_msg(
self, chat_id: int, msg_id: int, **fields: Dict[str, Any]
) -> bool:
msg = None
for message in self.msgs[chat_id]:
if message["id"] == msg_id:
@ -464,14 +468,14 @@ class UserModel:
def __init__(self, tg: Tdlib) -> None:
self.tg = tg
self.me = None
self.me: Dict[str, Any] = {}
self.users: Dict[int, Dict] = {}
self.groups: Dict[int, Dict] = {}
self.supergroups: Dict[int, Dict] = {}
self.actions: Dict[int, Dict] = {}
self.not_found: Set[int] = set()
def get_me(self):
def get_me(self) -> Dict[str, Any]:
if self.me:
return self.me
result = self.tg.get_me()
@ -493,7 +497,7 @@ class UserModel:
log.error(f"ChatAction type {action_type} not implemented")
return None
def set_status(self, user_id: int, status: Dict[str, Any]):
def set_status(self, user_id: int, status: Dict[str, Any]) -> None:
if user_id not in self.users:
self.get_user(user_id)
self.users[user_id]["status"] = status
@ -522,7 +526,7 @@ class UserModel:
return f"last seen {ago}"
return f"last seen {status.value}"
def is_online(self, user_id: int):
def is_online(self, user_id: int) -> bool:
user = self.get_user(user_id)
if (
user

View file

@ -32,7 +32,7 @@ class MsgProxy:
}
@classmethod
def get_doc(cls, msg, deep=10):
def get_doc(cls, msg: Dict[str, Any], deep: int = 10) -> Dict[str, Any]:
doc = msg["content"]
_type = doc["@type"]
fields = cls.fields_mapping.get(_type)
@ -48,13 +48,13 @@ class MsgProxy:
return {}
return doc
def __init__(self, msg: Dict[str, Any]):
def __init__(self, msg: Dict[str, Any]) -> None:
self.msg = msg
def __getitem__(self, key: str) -> Any:
return self.msg[key]
def __setitem__(self, key, value):
def __setitem__(self, key: str, value: Any) -> None:
self.msg[key] = value
@property
@ -66,39 +66,39 @@ class MsgProxy:
return datetime.fromtimestamp(self.msg["date"])
@property
def is_message(self):
def is_message(self) -> bool:
return self.type == "message"
@property
def content_type(self):
def content_type(self) -> Optional[str]:
return self.types.get(self.msg["content"]["@type"])
@property
def size(self):
def size(self) -> int:
doc = self.get_doc(self.msg)
return doc["size"]
@property
def human_size(self):
def human_size(self) -> str:
doc = self.get_doc(self.msg)
return utils.humanize_size(doc["size"])
@property
def duration(self):
def duration(self) -> Optional[str]:
if self.content_type not in ("audio", "voice", "video", "recording"):
return None
doc = self.get_doc(self.msg, deep=1)
return utils.humanize_duration(doc["duration"])
@property
def file_name(self):
def file_name(self) -> Optional[str]:
if self.content_type not in ("audio", "document", "video"):
return None
doc = self.get_doc(self.msg, deep=1)
return doc["file_name"]
@property
def file_id(self):
def file_id(self) -> Optional[int]:
if self.content_type not in (
"audio",
"document",
@ -113,26 +113,26 @@ class MsgProxy:
return doc["id"]
@property
def local_path(self):
def local_path(self) -> Optional[str]:
if self.msg["content"]["@type"] is None:
return None
doc = self.get_doc(self.msg)
return doc["local"]["path"]
@property
def local(self):
def local(self) -> Dict:
doc = self.get_doc(self.msg)
return doc["local"]
@local.setter
def local(self, value):
def local(self, value: Dict) -> None:
if self.msg["content"]["@type"] is None:
return None
return
doc = self.get_doc(self.msg)
doc["local"] = value
@property
def is_text(self):
def is_text(self) -> bool:
return self.msg["content"]["@type"] == "messageText"
@property
@ -140,7 +140,7 @@ class MsgProxy:
return self.msg["content"]["text"]["text"]
@property
def is_downloaded(self):
def is_downloaded(self) -> bool:
doc = self.get_doc(self.msg)
return doc["local"]["is_downloading_completed"]
@ -151,7 +151,7 @@ class MsgProxy:
return self.msg["content"]["is_listened"]
@is_listened.setter
def is_listened(self, value: bool):
def is_listened(self, value: bool) -> None:
if self.content_type == "voice":
self.msg["content"]["is_listened"] = value
@ -162,7 +162,7 @@ class MsgProxy:
return self.msg["content"]["is_viewed"]
@is_viewed.setter
def is_viewed(self, value: bool):
def is_viewed(self, value: bool) -> None:
if self.content_type == "recording":
self.msg["content"]["is_viewed"] = value

View file

@ -39,8 +39,13 @@ class UserStatus(Enum):
class Tdlib(Telegram):
def download_file(
self, file_id, priority=16, offset=0, limit=0, synchronous=False,
):
self,
file_id: int,
priority: int = 16,
offset: int = 0,
limit: int = 0,
synchronous: bool = False,
) -> None:
result = self.call_method(
"downloadFile",
params=dict(
@ -124,8 +129,8 @@ class Tdlib(Telegram):
return self._send_data(data)
def send_voice(
self, file_path: str, chat_id: int, duration: int, waveform: int
):
self, file_path: str, chat_id: int, duration: int, waveform: str
) -> AsyncResult:
data = {
"@type": "sendMessage",
"chat_id": chat_id,
@ -138,7 +143,9 @@ class Tdlib(Telegram):
}
return self._send_data(data)
def edit_message_text(self, chat_id: int, message_id: int, text: str):
def edit_message_text(
self, chat_id: int, message_id: int, text: str
) -> AsyncResult:
data = {
"@type": "editMessageText",
"message_id": message_id,
@ -172,7 +179,7 @@ class Tdlib(Telegram):
def set_chat_nottification_settings(
self, chat_id: int, notification_settings: dict
):
) -> AsyncResult:
data = {
"@type": "setChatNotificationSettings",
"chat_id": chat_id,

View file

@ -8,15 +8,17 @@ from tg.msg import MsgProxy
log = logging.getLogger(__name__)
_update_handler_type = Callable[[Controller, Dict[str, Any]], None]
UpdateHandler = Callable[[Controller, Dict[str, Any]], None]
handlers: Dict[str, _update_handler_type] = {}
handlers: Dict[str, UpdateHandler] = {}
max_download_size: int = utils.parse_size(config.MAX_DOWNLOAD_SIZE)
def update_handler(update_type):
def decorator(fun):
def update_handler(
update_type: str,
) -> Callable[[UpdateHandler], UpdateHandler]:
def decorator(fun: UpdateHandler) -> UpdateHandler:
global handlers
assert (
update_type not in handlers
@ -25,9 +27,9 @@ def update_handler(update_type):
handlers[update_type] = fun
@wraps(fun)
def wrapper(*args, **kwargs):
def wrapper(controller: Controller, update: Dict[str, Any]) -> None:
try:
return fun(*args, **kwargs)
return fun(controller, update)
except Exception:
log.exception("Error happened in %s handler", fun.__name__)
@ -37,7 +39,9 @@ def update_handler(update_type):
@update_handler("updateMessageContent")
def update_message_content(controller: Controller, update: Dict[str, Any]):
def update_message_content(
controller: Controller, update: Dict[str, Any]
) -> None:
chat_id = update["chat_id"]
message_id = update["message_id"]
controller.model.msgs.update_msg(
@ -50,7 +54,9 @@ def update_message_content(controller: Controller, update: Dict[str, Any]):
@update_handler("updateMessageEdited")
def update_message_edited(controller: Controller, update: Dict[str, Any]):
def update_message_edited(
controller: Controller, update: Dict[str, Any]
) -> None:
chat_id = update["chat_id"]
message_id = update["message_id"]
edit_date = update["edit_date"]
@ -62,7 +68,7 @@ def update_message_edited(controller: Controller, update: Dict[str, Any]):
@update_handler("updateNewMessage")
def update_new_message(controller: Controller, update: Dict[str, Any]):
def update_new_message(controller: Controller, update: Dict[str, Any]) -> None:
msg = MsgProxy(update["message"])
controller.model.msgs.add_message(msg.chat_id, msg.msg)
current_chat_id = controller.model.current_chat_id
@ -75,29 +81,29 @@ def update_new_message(controller: Controller, update: Dict[str, Any]):
@update_handler("updateChatOrder")
def update_chat_order(controller: Controller, update: Dict[str, Any]):
def update_chat_order(controller: Controller, update: Dict[str, Any]) -> None:
current_chat_id = controller.model.current_chat_id
chat_id = update["chat_id"]
order = update["order"]
if controller.model.chats.update_chat(chat_id, order=order):
controller._refresh_current_chat(current_chat_id)
controller.refresh_current_chat(current_chat_id)
@update_handler("updateChatTitle")
def update_chat_title(controller: Controller, update: Dict[str, Any]):
def update_chat_title(controller: Controller, update: Dict[str, Any]) -> None:
chat_id = update["chat_id"]
title = update["title"]
current_chat_id = controller.model.current_chat_id
if controller.model.chats.update_chat(chat_id, title=title):
controller._refresh_current_chat(current_chat_id)
controller.refresh_current_chat(current_chat_id)
@update_handler("updateChatIsMarkedAsUnread")
def update_chat_is_marked_as_unread(
controller: Controller, update: Dict[str, Any]
):
) -> None:
chat_id = update["chat_id"]
is_marked_as_unread = update["is_marked_as_unread"]
@ -105,11 +111,13 @@ def update_chat_is_marked_as_unread(
if controller.model.chats.update_chat(
chat_id, is_marked_as_unread=is_marked_as_unread
):
controller._refresh_current_chat(current_chat_id)
controller.refresh_current_chat(current_chat_id)
@update_handler("updateChatIsPinned")
def update_chat_is_pinned(controller: Controller, update: Dict[str, Any]):
def update_chat_is_pinned(
controller: Controller, update: Dict[str, Any]
) -> None:
chat_id = update["chat_id"]
is_pinned = update["is_pinned"]
order = update["order"]
@ -118,23 +126,27 @@ def update_chat_is_pinned(controller: Controller, update: Dict[str, Any]):
if controller.model.chats.update_chat(
chat_id, is_pinned=is_pinned, order=order
):
controller._refresh_current_chat(current_chat_id)
controller.refresh_current_chat(current_chat_id)
@update_handler("updateChatReadOutbox")
def update_chat_read_outbox(controller: Controller, update: Dict[str, Any]):
def update_chat_read_outbox(
controller: Controller, update: Dict[str, Any]
) -> None:
chat_id = update["chat_id"]
last_read_outbox_message_id = update["last_read_outbox_message_id"]
current_chat_id = controller.model.current_chat_id
if controller.model.chats.update_chat(
chat_id, last_read_outbox_message_id=last_read_outbox_message_id,
chat_id, last_read_outbox_message_id=last_read_outbox_message_id
):
controller._refresh_current_chat(current_chat_id)
controller.refresh_current_chat(current_chat_id)
@update_handler("updateChatReadInbox")
def update_chat_read_inbox(controller: Controller, update: Dict[str, Any]):
def update_chat_read_inbox(
controller: Controller, update: Dict[str, Any]
) -> None:
chat_id = update["chat_id"]
last_read_inbox_message_id = update["last_read_inbox_message_id"]
unread_count = update["unread_count"]
@ -145,11 +157,13 @@ def update_chat_read_inbox(controller: Controller, update: Dict[str, Any]):
last_read_inbox_message_id=last_read_inbox_message_id,
unread_count=unread_count,
):
controller._refresh_current_chat(current_chat_id)
controller.refresh_current_chat(current_chat_id)
@update_handler("updateChatDraftMessage")
def update_chat_draft_message(controller: Controller, update: Dict[str, Any]):
def update_chat_draft_message(
controller: Controller, update: Dict[str, Any]
) -> None:
chat_id = update["chat_id"]
# FIXME: ignoring draft message itself for now because UI can't show it
# draft_message = update["draft_message"]
@ -157,11 +171,13 @@ def update_chat_draft_message(controller: Controller, update: Dict[str, Any]):
current_chat_id = controller.model.current_chat_id
if controller.model.chats.update_chat(chat_id, order=order):
controller._refresh_current_chat(current_chat_id)
controller.refresh_current_chat(current_chat_id)
@update_handler("updateChatLastMessage")
def update_chat_last_message(controller: Controller, update: Dict[str, Any]):
def update_chat_last_message(
controller: Controller, update: Dict[str, Any]
) -> None:
chat_id = update["chat_id"]
last_message = update.get("last_message")
if not last_message:
@ -174,11 +190,13 @@ def update_chat_last_message(controller: Controller, update: Dict[str, Any]):
if controller.model.chats.update_chat(
chat_id, last_message=last_message, order=order
):
controller._refresh_current_chat(current_chat_id)
controller.refresh_current_chat(current_chat_id)
@update_handler("updateChatNotificationSettings")
def update_chat_notification_settings(controller: Controller, update):
def update_chat_notification_settings(
controller: Controller, update: Dict[str, Any]
) -> None:
chat_id = update["chat_id"]
notification_settings = update["notification_settings"]
if controller.model.chats.update_chat(
@ -188,7 +206,9 @@ def update_chat_notification_settings(controller: Controller, update):
@update_handler("updateMessageSendSucceeded")
def update_message_send_succeeded(controller: Controller, update):
def update_message_send_succeeded(
controller: Controller, update: Dict[str, Any]
) -> None:
chat_id = update["message"]["chat_id"]
msg_id = update["old_message_id"]
controller.model.msgs.add_message(chat_id, update["message"])
@ -200,7 +220,7 @@ def update_message_send_succeeded(controller: Controller, update):
@update_handler("updateFile")
def update_file(controller: Controller, update):
def update_file(controller: Controller, update: Dict[str, Any]) -> None:
file_id = update["file"]["id"]
local = update["file"]["local"]
chat_id, msg_id = controller.model.downloads.get(file_id, (None, None))
@ -223,7 +243,7 @@ def update_file(controller: Controller, update):
@update_handler("updateMessageContentOpened")
def update_message_content_opened(
controller: Controller, update: Dict[str, Any]
):
) -> None:
chat_id = update["chat_id"]
message_id = update["message_id"]
controller.model.msgs.update_msg_content_opened(chat_id, message_id)
@ -231,7 +251,9 @@ def update_message_content_opened(
@update_handler("updateDeleteMessages")
def update_delete_messages(controller: Controller, update: Dict[str, Any]):
def update_delete_messages(
controller: Controller, update: Dict[str, Any]
) -> None:
chat_id = update["chat_id"]
msg_ids = update["message_ids"]
for msg_id in msg_ids:
@ -240,7 +262,9 @@ def update_delete_messages(controller: Controller, update: Dict[str, Any]):
@update_handler("updateConnectionState")
def update_connection_state(controller: Controller, update: Dict[str, Any]):
def update_connection_state(
controller: Controller, update: Dict[str, Any]
) -> None:
state = update["state"]["@type"]
states = {
"connectionStateWaitingForNetwork": "Waiting for network...",
@ -255,27 +279,29 @@ def update_connection_state(controller: Controller, update: Dict[str, Any]):
@update_handler("updateUserStatus")
def update_user_status(controller: Controller, update: Dict[str, Any]):
def update_user_status(controller: Controller, update: Dict[str, Any]) -> None:
controller.model.users.set_status(update["user_id"], update["status"])
controller.render()
@update_handler("updateBasicGroup")
def update_basic_group(controller: Controller, update: Dict[str, Any]):
def update_basic_group(controller: Controller, update: Dict[str, Any]) -> None:
basic_group = update["basic_group"]
controller.model.users.groups[basic_group["id"]] = basic_group
controller.render_msgs()
@update_handler("updateSupergroup")
def update_supergroup(controller: Controller, update: Dict[str, Any]):
def update_supergroup(controller: Controller, update: Dict[str, Any]) -> None:
supergroup = update["supergroup"]
controller.model.users.supergroups[supergroup["id"]] = supergroup
controller.render_msgs()
@update_handler("updateUserChatAction")
def update_user_chat_action(controller: Controller, update: Dict[str, Any]):
def update_user_chat_action(
controller: Controller, update: Dict[str, Any]
) -> None:
chat_id = update["chat_id"]
if update["action"]["@type"] == "chatActionCancel":
controller.model.users.actions.pop(chat_id, None)

View file

@ -12,8 +12,9 @@ import struct
import subprocess
import sys
from datetime import datetime
from functools import wraps
from typing import Optional
from logging.handlers import RotatingFileHandler
from types import TracebackType
from typing import Any, Optional, Tuple, Type
from tg import config
@ -33,29 +34,30 @@ units = {"B": 1, "KB": 10 ** 3, "MB": 10 ** 6, "GB": 10 ** 9, "TB": 10 ** 12}
class LogWriter:
def __init__(self, level):
def __init__(self, level: Any) -> None:
self.level = level
def write(self, message):
def write(self, message: str) -> None:
if message != "\n":
self.level.log(self.level, message)
def flush(self):
def flush(self) -> None:
pass
def setup_log():
def setup_log() -> None:
handlers = []
for level, filename in zip(
(config.LOG_LEVEL, logging.ERROR), ("all.log", "error.log"),
for level, filename in (
(config.LOG_LEVEL, "all.log"),
(logging.ERROR, "error.log"),
):
handler = logging.handlers.RotatingFileHandler(
handler = RotatingFileHandler(
os.path.join(config.LOG_PATH, filename),
maxBytes=parse_size("32MB"),
backupCount=1,
)
handler.setLevel(level)
handler.setLevel(level) # type: ignore
handlers.append(handler)
logging.basicConfig(
@ -63,11 +65,11 @@ def setup_log():
handlers=handlers,
)
logging.getLogger().setLevel(config.LOG_LEVEL)
sys.stderr = LogWriter(log.error)
sys.stderr = LogWriter(log.error) # type: ignore
logging.captureWarnings(True)
def get_file_handler(file_path, default=None):
def get_file_handler(file_path: str, default: str = None) -> Optional[str]:
mtype, _ = mimetypes.guess_type(file_path)
if not mtype:
return default
@ -87,8 +89,10 @@ def parse_size(size: str) -> int:
def humanize_size(
num, suffix="B", suffixes=("", "K", "M", "G", "T", "P", "E", "Z")
):
num: int,
suffix: str = "B",
suffixes: Tuple[str, ...] = ("", "K", "M", "G", "T", "P", "E", "Z",),
) -> str:
magnitude = int(math.floor(math.log(num, 1024)))
val = num / math.pow(1024, magnitude)
if magnitude > 7:
@ -96,7 +100,7 @@ def humanize_size(
return "{:3.1f}{}{}".format(val, suffixes[magnitude], suffix)
def humanize_duration(seconds):
def humanize_duration(seconds: int) -> str:
dt = datetime.utcfromtimestamp(seconds)
fmt = "%-M:%S"
if seconds >= 3600:
@ -111,13 +115,13 @@ def num(value: str, default: Optional[int] = None) -> Optional[int]:
return default
def is_yes(resp):
def is_yes(resp: str) -> bool:
if resp.strip().lower() == "y" or resp == "":
return True
return False
def get_duration(file_path):
def get_duration(file_path: str) -> int:
cmd = f"ffprobe -v error -i '{file_path}' -show_format"
stdout = subprocess.check_output(shlex.split(cmd)).decode().splitlines()
line = next((line for line in stdout if "duration" in line), None)
@ -128,14 +132,14 @@ def get_duration(file_path):
return 0
def get_video_resolution(file_path):
def get_video_resolution(file_path: str) -> Tuple[int, int]:
cmd = f"ffprobe -v error -show_entries stream=width,height -of default=noprint_wrappers=1 '{file_path}'"
lines = subprocess.check_output(shlex.split(cmd)).decode().splitlines()
info = {line.split("=")[0]: line.split("=")[1] for line in lines}
return info.get("width"), info.get("height")
return int(str(info.get("width"))), int(str(info.get("height")))
def get_waveform(file_path):
def get_waveform(file_path: str) -> str:
# mock for now
waveform = (random.randint(0, 255) for _ in range(100))
packed = struct.pack("100B", *waveform)
@ -143,8 +147,11 @@ def get_waveform(file_path):
def notify(
msg, subtitle="", title="tg", cmd=config.NOTIFY_CMD,
):
msg: str,
subtitle: str = "",
title: str = "tg",
cmd: str = config.NOTIFY_CMD,
) -> None:
if not cmd:
return
notify_cmd = cmd.format(
@ -156,45 +163,35 @@ def notify(
os.system(notify_cmd)
def handle_exception(fun):
@wraps(fun)
def wrapper(*args, **kwargs):
try:
return fun(*args, **kwargs)
except Exception:
log.exception("Error happened in %s handler", fun.__name__)
return wrapper
def truncate_to_len(s: str, target_len: int, encoding: str = "utf-8") -> str:
def truncate_to_len(s: str, target_len: int) -> str:
target_len -= sum(map(bool, map(emoji_pattern.findall, s[:target_len])))
return s[: max(1, target_len - 1)]
def copy_to_clipboard(text):
def copy_to_clipboard(text: str) -> None:
subprocess.run(
config.COPY_CMD, universal_newlines=True, input=text, shell=True
)
class suspend:
def __init__(self, view):
# FIXME: can't explicitly set type "View" due to circular import
def __init__(self, view: Any) -> None:
self.view = view
def call(self, cmd):
def call(self, cmd: str) -> None:
subprocess.call(cmd, shell=True)
def run_with_input(self, cmd, text):
def run_with_input(self, cmd: str, text: str) -> None:
subprocess.run(cmd, universal_newlines=True, input=text, shell=True)
def open_file(self, file_path):
def open_file(self, file_path: str) -> None:
cmd = get_file_handler(file_path)
if not cmd:
return
self.call(cmd)
def __enter__(self):
def __enter__(self) -> "suspend":
for view in (self.view.chats, self.view.msgs, self.view.status):
view._refresh = view.win.noutrefresh
curses.echo()
@ -204,7 +201,12 @@ class suspend:
curses.endwin()
return self
def __exit__(self, exc_type, exc_val, tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
for view in (self.view.chats, self.view.msgs, self.view.status):
view._refresh = view.win.refresh
curses.noecho()
@ -214,7 +216,7 @@ class suspend:
curses.doupdate()
def set_shorter_esc_delay(delay=25):
def set_shorter_esc_delay(delay: int = 25) -> None:
os.environ.setdefault("ESCDELAY", str(delay))

View file

@ -2,7 +2,7 @@ import curses
import logging
from _curses import window # type: ignore
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, cast
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from tg import config
from tg.colors import (
@ -92,7 +92,7 @@ class StatusView:
self.win = stdscr.subwin(self.h, self.w, self.y, self.x)
self._refresh = self.win.refresh
def resize(self, rows: int, cols: int):
def resize(self, rows: int, cols: int) -> None:
self.w = cols - 1
self.y = rows - 1
self.win.resize(self.h, self.w)
@ -105,7 +105,7 @@ class StatusView:
self.win.addstr(0, 0, msg[: self.w])
self._refresh()
def get_input(self, msg="") -> str:
def get_input(self, msg: str = "") -> str:
self.draw(msg)
curses.curs_set(1)
@ -136,7 +136,7 @@ class StatusView:
class ChatView:
def __init__(self, stdscr: window, model: Model):
def __init__(self, stdscr: window, model: Model) -> None:
self.stdscr = stdscr
self.h = 0
self.w = 0
@ -258,9 +258,7 @@ class ChatView:
class MsgView:
def __init__(
self, stdscr: window, model: Model,
):
def __init__(self, stdscr: window, model: Model,) -> None:
self.model = model
self.stdscr = stdscr
self.h = 0
@ -280,7 +278,7 @@ class MsgView:
self.win.resize(self.h, self.w)
self.win.mvwin(0, self.x)
def _get_flags(self, msg_proxy: MsgProxy):
def _get_flags(self, msg_proxy: MsgProxy) -> str:
flags = []
chat = self.model.chats.chats[self.model.current_chat]
@ -329,7 +327,7 @@ class MsgView:
msg = f"{reply_line}\n{msg}"
return msg
def _format_url(self, msg_proxy: MsgProxy):
def _format_url(self, msg_proxy: MsgProxy) -> str:
if not msg_proxy.is_text or "web_page" not in msg_proxy.msg["content"]:
return ""
web = msg_proxy.msg["content"]["web_page"]
@ -493,7 +491,7 @@ class MsgView:
log.error(f"ChatType {chat['type']} not implemented")
return None
def _msg_title(self, chat: Dict[str, Any]):
def _msg_title(self, chat: Dict[str, Any]) -> str:
chat_type = self._get_chat_type(chat)
status = ""
if action := self.model.users.get_action(chat["id"]):
@ -598,11 +596,11 @@ def format_bool(value: Optional[bool]) -> Optional[str]:
return "yes" if value else "no"
def get_download(local, size):
def get_download(local: Dict[str, Union[str, bool, int]], size: int) -> str:
if local["is_downloading_completed"]:
return "yes"
elif local["is_downloading_active"]:
d = local["downloaded_size"]
d = int(local["downloaded_size"])
percent = int(d * 100 / size)
return f"{percent}%"
return "no"