# nobodd: a boot configuration tool for the Raspberry Pi
#
# Copyright (c) 2023-2024 Dave Jones <dave.jones@canonical.com>
# Copyright (c) 2023-2024 Canonical Ltd.
#
# SPDX-License-Identifier: GPL-3.0
import os
import codecs
from . import lang
# The following references were essential in constructing this module: the
# original TELNET specification [RFC764], the TFTP specification [RFC1350],
# and the Wikipedia page documenting the TFTP protocol [1].
#
# [1]: https://en.wikipedia.org/wiki/Trivial_File_Transfer_Protocol
# [RFC764]: https://datatracker.ietf.org/doc/html/rfc764
# [RFC1350]: https://datatracker.ietf.org/doc/html/rfc1350
# The platform line separator as ASCII bytes; this is what a newline looks
# like in the str handed to encode() and in the str produced by decode()
_netascii_linesep = os.linesep.encode('ascii')


def encode(s, errors='strict', final=False):
    """
    Encodes the :class:`str` *s*, which must only contain valid ASCII
    characters, to the netascii :class:`bytes` representation.

    The platform line separator is encoded as '\\r\\n' and any other '\\r' is
    encoded as '\\r\\0'; all other characters are copied through unchanged.

    The *errors* parameter specifies the handling of encoding errors in the
    typical manner ('strict', 'ignore', 'replace', etc). The *final* parameter
    indicates whether this is the end of the input. This only matters on the
    Windows platform where the line separator is '\\r\\n' in which case a
    trailing '\\r' character *may* be the start of a newline sequence.

    The return value is a tuple of the encoded :class:`bytes` string, and the
    number of characters consumed from *s* (this may be less than the length of
    *s* when *final* is :data:`False`). The count always refers to characters
    of *s*, including when the *errors* handler drops or expands characters.
    """
    # We can pre-allocate the output array as the transform guarantees the
    # length of output <= 2 * length of the input (largest transform in all
    # cases is b'\r' -> b'\r\0')
    buf_in = s.encode('ascii', errors=errors)
    buf_out = bytearray(len(buf_in) * 2)
    pos_in = pos_out = 0

    def encode_newline():
        # A platform line separator becomes the netascii newline
        nonlocal pos_out, pos_in
        buf_out[pos_out:pos_out + 2] = b'\r\n'
        pos_out += 2
        pos_in += len(_netascii_linesep)

    def encode_cr():
        # A carriage-return that is not part of a newline is NUL-terminated
        nonlocal pos_out, pos_in
        buf_out[pos_out:pos_out + 2] = b'\r\0'
        pos_out += 2
        pos_in += 1

    while pos_in < len(buf_in):
        # Find the next byte needing translation: either the first byte of
        # the line separator, or a carriage-return (end of input if neither)
        candidates = (
            buf_in.find(_netascii_linesep[0], pos_in),
            buf_in.find(b'\r', pos_in),
        )
        i = min(len(buf_in) if j == -1 else j for j in candidates)
        if i > pos_in:
            # Copy the run of ordinary bytes verbatim
            buf_out[pos_out:pos_out + i - pos_in] = buf_in[pos_in:i]
            pos_out += i - pos_in
            pos_in = i
        elif len(_netascii_linesep) == 1:
            # Non-windows case
            if buf_in[i] == _netascii_linesep[0]:
                encode_newline()
            else:  # buf_in[i] == b'\r'[0]
                encode_cr()
        else:
            # Windows case
            if len(buf_in) > pos_in + 1:
                if buf_in[i + 1] == _netascii_linesep[1]:
                    encode_newline()
                else:
                    encode_cr()
            else:
                # Trailing '\r': without more input we cannot tell whether
                # it starts a newline, so hold it back unless this is the end
                if final:
                    encode_cr()
                break

    if pos_in == len(buf_in):
        # Everything was encoded. Report the length of *s*, not of buf_in:
        # the two differ when the error handler drops or expands characters
        # (e.g. 'ignore', 'backslashreplace') and callers use this count to
        # slice *s*
        consumed = len(s)
    else:
        # The only thing ever held back is a single trailing b'\r'; resume
        # from the character of *s* that produced it
        consumed = s.rindex('\r')
    return bytes(buf_out[:pos_out]), consumed
def decode(s, errors='strict', final=False):
    """
    Decodes the netascii encoded :class:`bytes` string *s* to its
    :class:`str` representation (which can only contain ASCII characters).

    The sequence '\\r\\n' is decoded to the platform line separator and
    '\\r\\0' to a bare '\\r'. A '\\r' followed by any other byte is invalid
    and is dealt with according to *errors*.

    The *errors* parameter selects the error handling scheme in the usual way
    ('strict', 'ignore', 'replace', etc). The *final* parameter states whether
    *s* is the end of the input; this matters because a trailing '\\r' may be
    the first half of a two byte sequence, so it is only treated as an error
    when no further input can follow.

    Returns a tuple of the decoded :class:`str` and the number of bytes
    consumed from *s* (which may be less than the length of *s* when *final*
    is :data:`False`).
    """
    nul, lf = 0x00, 0x0A
    src = bytes(s)
    end = len(src)
    # The output is pre-allocated at the size of the input; the slice below
    # trims it to the part actually written
    dst = bytearray(end)
    rd = wr = 0

    def emit(chunk):
        # Append *chunk* to the output at the current write position
        nonlocal wr
        dst[wr:wr + len(chunk)] = chunk
        wr += len(chunk)

    while rd < end:
        cr_at = src.find(b'\r', rd)
        if cr_at != rd:
            # Ordinary bytes up to the next carriage-return (or the end of
            # the input) are copied verbatim
            stop = end if cr_at == -1 else cr_at
            emit(src[rd:stop])
            rd = stop
            continue

        # src[rd] is a carriage-return from here on
        if rd + 1 >= end:
            # Trailing '\r': only an error if nothing more can follow it,
            # otherwise leave it unconsumed for the next call
            if final:
                emit(handle_error(errors))
                rd += 1
            break

        follower = src[rd + 1]
        if follower == nul:
            emit(b'\r')
            rd += 2
        elif follower == lf:
            emit(_netascii_linesep)
            rd += 2
        else:
            # Invalid sequence; skip the '\r' alone so the following byte is
            # still decoded normally
            emit(handle_error(errors))
            rd += 1

    return dst[:wr].decode('ascii', errors=errors), rd
def handle_error(errors):
    """
    Return the replacement bytes for an invalid netascii sequence under the
    error handling scheme *errors*.

    'ignore' yields an empty byte string and 'replace' yields b'?'. 'strict'
    raises :exc:`UnicodeError`, and any other value raises :exc:`ValueError`.
    """
    if errors == 'ignore':
        return b''
    if errors == 'replace':
        return b'?'
    if errors == 'strict':
        raise UnicodeError(lang._('invalid netascii'))
    # Only the three schemes above are supported by this codec
    raise ValueError(lang._('invalid errors setting for netascii'))
class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
    r"""
    Incremental netascii encoder, built on the module level :func:`encode`.

    Typically driven via :func:`codecs.iterencode`:

    .. code-block:: pycon

        >>> import os
        >>> os.linesep
        '\n'
        >>> import nobodd.netascii
        >>> import codecs
        >>> it = ['foo', '\n', 'bar\r']
        >>> b''.join(codecs.iterencode(it, 'netascii'))
        b'foo\r\nbar\r\0'
    """
    @staticmethod
    def _buffer_encode(s, errors, final=False):
        # Hook called by the buffered base class; the actual work is done by
        # the module level function
        return encode(s, errors=errors, final=final)
class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
    r"""
    Incremental netascii decoder, built on the module level :func:`decode`.

    Use :func:`codecs.iterdecode` to utilize this class for decoding:

    .. code-block:: pycon

        >>> import os
        >>> os.linesep
        '\n'
        >>> import nobodd.netascii
        >>> import codecs
        >>> it = [b'foo\r', b'\n', b'bar\r', b'\0']
        >>> ''.join(codecs.iterdecode(it, 'netascii'))
        'foo\nbar\r'
    """
    @staticmethod
    def _buffer_decode(s, errors, final=False):
        # Hook called by the buffered base class; a trailing '\r' that
        # decode() leaves unconsumed is reported through the returned count
        return decode(s, errors, final)
class StreamWriter(codecs.StreamWriter):
    """
    Stream writer that netascii-encodes text written to *stream*.

    Characters that :func:`encode` declines to consume are held in an
    internal buffer and prefixed to the text of the next write.
    """
    def __init__(self, stream, errors='strict'):
        super().__init__(stream, errors)
        # Only True for the duration of flush()
        self._final = False
        self.reset()

    def encode(self, s, errors='strict'):
        """
        Encode the buffered text followed by *s*, keeping whatever is left
        unconsumed for the next call.

        Returns the tuple produced by the module level :func:`encode`; the
        consumed count is relative to the buffered text plus *s*.
        """
        pending = self._buf + s
        encoded, consumed = encode(pending, errors, final=self._final)
        self._buf = pending[consumed:]
        return encoded, consumed

    def flush(self):
        """
        Encode any buffered text as the end of the input, then flush the
        underlying stream.
        """
        self._final = True
        try:
            # An empty write pushes the buffered text through encode() with
            # final=True
            self.write('')
            self.stream.flush()
        finally:
            self._final = False

    def reset(self):
        """
        Reset the writer, discarding any buffered text.
        """
        super().reset()
        self._buf = ''
class StreamReader(codecs.StreamReader):
    """
    Stream reader that decodes netascii bytes read from the wrapped stream.
    """
    def decode(self, s, errors='strict', final=False):
        """
        Decode the netascii :class:`bytes` *s* by delegating to the module
        level :func:`decode`, returning its (str, consumed) tuple.
        """
        return decode(s, errors=errors, final=final)
def stateless_encode(s, errors='strict'):
    """
    Encode the whole of *s* to netascii in a single call, treating it as the
    complete input. Returns the (bytes, consumed) tuple from :func:`encode`.
    """
    return encode(s, errors=errors, final=True)
def stateless_decode(s, errors='strict'):
    """
    Decode the whole of *s* from netascii in a single call, treating it as
    the complete input. Returns the (str, consumed) tuple from :func:`decode`.
    """
    return decode(s, errors=errors, final=True)
def find_netascii(name):
    """
    Codec search function for :func:`codecs.register`.

    Returns the :class:`codecs.CodecInfo` describing the netascii codec when
    *name* is 'netascii' (compared case-insensitively), and :data:`None` for
    any other name.
    """
    if name.lower() != 'netascii':
        return None
    return codecs.CodecInfo(
        name='netascii',
        encode=stateless_encode,
        decode=stateless_decode,
        incrementalencoder=IncrementalEncoder,
        incrementaldecoder=IncrementalDecoder,
        streamreader=StreamReader,
        streamwriter=StreamWriter,
    )
# Register the search function at module level, so the 'netascii' codec name
# becomes resolvable as soon as this module has been imported
codecs.register(find_netascii)