diff --git a/bot.py b/bot.py index b9d9cde..51bd796 100644 --- a/bot.py +++ b/bot.py @@ -4,6 +4,8 @@ import asyncio import inspect import logging +import irc + logging.basicConfig(format="[%(asctime)s] [%(levelname)s] %(message)s", level=logging.DEBUG, datefmt="%d.%m.%Y %H:%M:%S") logger = logging.getLogger(__name__) @@ -13,11 +15,12 @@ class ManagedProtocol(asyncio.Protocol): Inherit this to overlay the management with actual protocol parsing. """ - def __init__(self, loop, connection_manager, endpoint): + def __init__(self, config=None, loop=None, connection_manager=None, endpoint=None): self._loop = loop self._connection_manager = connection_manager self._endpoint = endpoint self._transport = None + self._config = config def _log(self, msg): host, port = self._endpoint @@ -48,60 +51,72 @@ class ManagedProtocol(asyncio.Protocol): """ Triggered by ConnectionManager.remove_endpoint(). Closes transport. """ self._transport.close() + def get_config(self): + return self._config class IrcProtocol(ManagedProtocol): """Implementation of the IRC protocol. """ - def __init__(self, loop, connection_manager, endpoint): - super(IrcProtocol, self).__init__(loop, connection_manager, endpoint) - self._loop = loop + def __init__(self, *args, **kwargs): + super(IrcProtocol, self).__init__(*args, **kwargs) self.motd = False self.hello = False + self._config = self.get_config() + self._buffer = b"" + + def encode(self, str): + return str.encode(self._config["encoding"], "replace") + + def decode(self, bytes): + return bytes.decode(self._config["encoding"], "replace") def connection_made(self, transport): super(IrcProtocol, self).connection_made(transport) - self.send_data(b"USER as as as :as\r\n") - self.send_data(b"NICK Pb42\r\n") - pass + self.send_data(b"USER " + self.encode(self._config["user"]) + b" dummy dummy :" + + self.encode(self._config["realname"]) + b"\r\n") + self.send_data(b"NICK " + self.encode(self._config["nick"]) + b"\r\n") def data_received(self, data): super(IrcProtocol, self).data_received(data) - pass - - def eof_received(self): - super(IrcProtocol, self).eof_received() - pass - - def connection_lost(self, exc): - super(IrcProtocol, self).connection_lost() - pass + self._buffer += data + self.process_data() + def process_data(self): + while b'\r\n' in self._buffer: + line, self._buffer = self._buffer.split(b'\r\n', 1) + line = self.decode(line.strip()) + irc_line = irc.IrcLine.from_string(line) + print(self.encode(str(irc_line))) class ConnectionManager(object): """Takes care of known endpoints that a connections shall be established to. + Stores configurations for every configuration. """ def __init__(self, loop): self._loop = loop self._endpoints = [] + self._configs = {} self._active_connections = {} self._loop.set_exception_handler(self._handle_async_exception) - def add_endpoint(self, endpoint): + def add_endpoint(self, endpoint, config): logger.debug("Endpoint added: {}:{}".format(*endpoint)) self._endpoints.append(endpoint) + self._configs[endpoint] = config self._create_connection(endpoint) def _create_connection(self, endpoint): - protocol = IrcProtocol(self._loop, self, endpoint) + protocol = IrcProtocol(config=self._configs[endpoint], loop=self._loop, connection_manager=self, endpoint=endpoint) coroutine = self._loop.create_connection(lambda: protocol, *endpoint) - asyncio.async(coroutine) + asyncio.ensure_future(coroutine) def remove_endpoint(self, endpoint): logger.debug("Endpoint removed: {}:{}".format(*endpoint)) self._endpoints.remove(endpoint) + del self._configs[endpoint] if endpoint in self._active_connections: self._active_connections[endpoint].close() @@ -128,14 +143,16 @@ class ConnectionManager(object): if __name__ == "__main__": - freenode = ("irc.freenode.net", 6667) - euirc = ("irc.euirc.net", 6667) - loop = asyncio.get_event_loop() connection_manager = ConnectionManager(loop) - connection_manager.add_endpoint(euirc) - connection_manager.add_endpoint(freenode) + connection_manager.add_endpoint(("irc.freenode.net", 6667), { + "encoding": "utf-8", + "nick": "Pb42", + "user": "foobar", + "realname": "Baz McBatzen", + "channels": ["#botted"] + }) try: loop.run_forever() diff --git a/irc.py b/irc.py new file mode 100644 index 0000000..55bcbe5 --- /dev/null +++ b/irc.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- + +def parse(line): + if line[0:1] == ':': + prefix, line = line.split(None, 1) + prefix = prefix[1:] + else: + prefix = "" + if ' :' in line: + tmp_str, trailing = line.split(' :', 1) + tmp_args = tmp_str.split() + else: + trailing = "" + tmp_args = line.split() + command, *middle = tmp_args + params = middle[:] + return prefix, command, params, trailing + + +class IrcLine(object): + """Handles translation between strings and IrcLines + """ + def __init__(self): + self.prefix = "" + self.command = "" + self.params = "" + self.trailing = "" + + @classmethod + def from_string(cls, string): + instance = cls() + data = parse(string) + print(data) + instance.prefix = data[0] + instance.command = data[1] + instance.params = data[2] + instance.trailing = data[3] + return instance + + def __repr__(self): + e = [] + if self.prefix: + e.append(self.prefix) + if self.command: + e.append(self.command) + if self.params: + e.append(" ".join(self.params)) + if self.trailing: + e.append(":{}".format(self.trailing)) + result = " ".join(e) + return result + + @classmethod + def kick(cls, channel, user, msg="KICK"): + instance = cls() + instance.command = "KICK" + instance.params = [channel, user] + instance.trailing = msg + return instance + +if __name__ == '__main__': + l = IrcLine.from_string(":JPT|NC!~AS@euirc-6f528752.pools.arcor-ip.net JOIN :#euirc") + print(str(l)) + print() + + l = IrcLine.from_string(":ChanServ!services@euirc.net MODE #Tonari. +ao JPT JPT") + print(str(l)) + print() + + line = IrcLine.kick("#botted", "JPT", "Du Sack!") + print(str(line)) + line2 = IrcLine.from_string(str(line)) + print(str(line2)) + exit() + +