Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions appdaemon/plugins/hass/hassplugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from datetime import datetime
from time import perf_counter
from typing import Any, Literal, Optional
from urllib.parse import urlencode

import aiohttp
from aiohttp import ClientResponseError, WebSocketError, WSMsgType
Expand Down Expand Up @@ -373,9 +374,6 @@ async def websocket_send_json(
Returns:
A dict containing the response from Home Assistant.
"""
request = utils.clean_kwargs(request)
request = utils.remove_literals(request, (None,))

if not self.connect_event.is_set():
self.logger.debug("Not connected to websocket, skipping JSON send.")
return
Expand All @@ -387,7 +385,7 @@ async def websocket_send_json(

if not silent:
# include this in the "not auth" section so we don't accidentally put the token in the logs
req_json = json.dumps(request, indent=4)
req_json = utils.convert_json(request, indent=4)
for i, line in enumerate(req_json.splitlines()):
if i == 0:
self.logger.debug(f"Sending JSON: {line}")
Expand All @@ -396,7 +394,7 @@ async def websocket_send_json(

send_time = perf_counter()
try:
await self.ws.send_json(request)
await self.ws.send_json(request, dumps=utils.convert_json)
# happens when the connection closes in the middle, which could be during shutdown
except ConnectionResetError:
if self.AD.stopping:
Expand All @@ -405,7 +403,7 @@ async def websocket_send_json(
else:
raise # Something bad actually happened, so raise the exception

self.update_perf(bytes_sent=len(json.dumps(request)), requests_sent=1)
self.update_perf(bytes_sent=len(utils.convert_json(request)), requests_sent=1)

match request:
case {"type": "auth"}:
Expand Down Expand Up @@ -454,25 +452,25 @@ async def http_method(
**kwargs (optional): Zero or more keyword arguments. These get used as the data for the method, as
appropriate.
"""
kwargs = utils.clean_http_kwargs(kwargs)
url = self.config.ha_url / endpoint.lstrip("/")

try:
self.update_perf(
bytes_sent=len(str(url)) + len(json.dumps(kwargs).encode("utf-8")),
requests_sent=1,
)

self.logger.debug(f"Hass {method.upper()} {endpoint}: {kwargs}")
match method.lower():
case "get":
http_method = functools.partial(self.session.get, params=kwargs)
cleaned = utils.clean_http_params_for_urlencode(kwargs)
payload_size = len(urlencode(cleaned).encode("utf-8"))
http_method = functools.partial(self.session.get, params=cleaned)
case "post":
payload_size = len(utils.convert_json(kwargs).encode("utf-8"))
http_method = functools.partial(self.session.post, json=kwargs)
case "delete":
http_method = functools.partial(self.session.delete, params=kwargs)
cleaned = utils.clean_http_params_for_urlencode(kwargs)
payload_size = len(urlencode(cleaned).encode("utf-8"))
http_method = functools.partial(self.session.delete, params=cleaned)
case _:
raise ValueError(f"Invalid method: {method}")
self.update_perf(bytes_sent=len(str(url)) + payload_size, requests_sent=1)

timeout = utils.parse_timedelta(timeout)
client_timeout = aiohttp.ClientTimeout(total=timeout.total_seconds())
Expand Down
64 changes: 35 additions & 29 deletions appdaemon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,49 +1189,55 @@ def time_str(start: float, now: float | None = None) -> str:
return format_timedelta((now or perf_counter()) - start)


def clean_kwargs(val: Any, *, http: bool = False) -> Any:
"""Recursively clean a dict of kwargs.

Conversions:
- datetime values are converted to ISO format strings
- Mapping values (like dicts) are converted to dicts of cleaned key-value pairs
- Iterable values (like lists and tuples) are converted to lists of cleaned values
- Other values are converted to strings
def remove_literals(val: Any, literal: Sequence[Any]) -> Any:
"""Remove instances of literals from a nested data structure.

Uses identity comparison (``is``) rather than equality (``==``)
to avoid ``0 == False`` and ``0.0 == False`` pitfalls.
"""
def _is_literal(v: Any) -> bool:
return any(v is lit for lit in literal)

match val:
case True if http:
return "true"
case str() | int() | float() | bool() | None:
case str():
return val
case datetime():
return val.isoformat()
case Mapping():
return {k: clean_kwargs(v, http=http) for k, v in val.items()}
return {k: remove_literals(v, literal) for k, v in val.items() if not _is_literal(v)}
case Iterable():
return [clean_kwargs(v, http=http) for v in val]
return [remove_literals(v, literal) for v in val if not _is_literal(v)]
case _:
return str(val)
return val


def remove_literals(val: Any, literal: Sequence[Any]) -> Any:
"""Remove instances of literals from a nested data structure."""
def clean_http_params_for_urlencode(val: Any) -> Any:
"""Recursively cleans kwargs for use as URL query parameters.

- None and False are excluded (HA treats param presence as enabled)
- True is converted to "true"
- datetime objects are converted to ISO format
- Other values are kept as-is
"""
match val:
case str():
case True:
return "true"
case str() | int() | float():
return val
case datetime():
return val.isoformat()
case Mapping():
return {k: remove_literals(v, literal) for k, v in val.items() if v not in literal}
return {
k: clean_http_params_for_urlencode(v)
for k, v in val.items()
if v is not None and v is not False
}
case Iterable():
return [remove_literals(v, literal) for v in val if v not in literal]
return [
clean_http_params_for_urlencode(v)
for v in val
if v is not None and v is not False
]
case _:
return val


def clean_http_kwargs(val: Any) -> Any:
"""Recursively cleans the kwarg dict to prepare it for use in HTTP requests."""
cleaned = clean_kwargs(val, http=True)
pruned = remove_literals(cleaned, (None, False))
return pruned
return str(val)


def unwrapped(func: Callable) -> Callable:
Expand Down
171 changes: 146 additions & 25 deletions tests/unit/test_kwarg_clean.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from copy import deepcopy
import json
from datetime import datetime

import pytest
import pytz
from appdaemon.utils import clean_http_kwargs, clean_kwargs, remove_literals
from appdaemon.utils import clean_http_params_for_urlencode, convert_json, remove_literals

pytestmark = [
pytest.mark.ci,
Expand All @@ -22,34 +22,42 @@
}


def test_clean_kwargs():
cleaned = clean_kwargs(BASE)
pruned = remove_literals(BASE, (None,))
assert isinstance(cleaned["f"], str)

def test_clean_http_params_for_urlencode():
cleaned = clean_http_params_for_urlencode(BASE)
assert cleaned["a"] == 1
assert cleaned["b"] == 2.0
assert cleaned["c"] == "three"
assert cleaned["d"] is True
assert cleaned["e"] is False
assert "g" not in pruned

kwargs = deepcopy(BASE)

kwargs["nested"] = deepcopy(BASE)
kwargs["nested"]["extra"] = deepcopy(BASE)
cleaned = clean_kwargs(kwargs)
assert isinstance(cleaned["nested"]["extra"]["f"], str)


def test_clean_http_kwargs():
cleaned = clean_http_kwargs(BASE)
assert isinstance(cleaned["f"], str)
assert cleaned["d"] == "true"
assert "e" not in cleaned
assert isinstance(cleaned["f"], str)
assert "g" not in cleaned


def test_clean_http_params_for_urlencode_preserves_zero():
"""0 and 0.0 must survive clean_http_params_for_urlencode (0 == False but 0 is not False)."""
data = {"offset": 0, "price": 0.0, "flag": False, "name": "test"}
cleaned = clean_http_params_for_urlencode(data)
assert cleaned["offset"] == 0
assert cleaned["price"] == 0.0
assert "flag" not in cleaned
assert cleaned["name"] == "test"


def test_clean_http_params_for_urlencode_nested():
"""Nested dicts and datetimes are cleaned recursively."""
data = {
"outer": {
"inner": {
"dt": datetime(2025, 9, 22, 12, 0, 0, tzinfo=pytz.utc),
"gone": None,
}
}
}
cleaned = clean_http_params_for_urlencode(data)
assert cleaned["outer"]["inner"]["dt"] == "2025-09-22T12:00:00+00:00"
assert "gone" not in cleaned["outer"]["inner"]


SERVICE_CALL = {
'type': 'call_service',
'domain': 'notify',
Expand All @@ -68,8 +76,9 @@ def test_clean_http_kwargs():
}


def test_websocket_service_call_kwargs():
cleaned = clean_kwargs(SERVICE_CALL)
def test_clean_http_params_for_urlencode_complex_nested():
"""Complex nested structure (like a service call) is cleaned correctly."""
cleaned = clean_http_params_for_urlencode(SERVICE_CALL)
match cleaned:
case {
"service_data":
Expand All @@ -87,7 +96,119 @@ def test_websocket_service_call_kwargs():
case _:
assert False, "Action format incorrect"
case _:
assert False, "Action format incorrect"
assert False, "Structure format incorrect"


def test_remove_literals_strips_none_from_service_call():
pruned = remove_literals(SERVICE_CALL, (None,))
assert "timeout" not in pruned["service_data"]


def test_remove_literals_preserves_zero():
"""remove_literals must use identity (is), not equality (==), to avoid 0 == False."""
data = {"a": 0, "b": 0.0, "c": False, "d": None, "e": "hello"}
pruned = remove_literals(data, (None, False))
assert pruned["a"] == 0
assert pruned["b"] == 0.0
assert "c" not in pruned
assert "d" not in pruned
assert pruned["e"] == "hello"


class TestConvertJson:
"""convert_json is the JSON serializer used by the aiohttp session and websocket."""

def test_datetime_uses_isoformat(self):
dt = datetime(2025, 6, 15, 10, 0, 0, tzinfo=pytz.utc)
result = convert_json({"timestamp": dt})
parsed = json.loads(result)
assert parsed["timestamp"] == "2025-06-15T10:00:00+00:00"

def test_booleans_are_json_booleans(self):
result = convert_json({"flag": True, "other": False})
parsed = json.loads(result)
assert parsed["flag"] is True
assert parsed["other"] is False

def test_none_becomes_null(self):
result = convert_json({"value": None})
parsed = json.loads(result)
assert parsed["value"] is None

def test_zero_preserved(self):
result = convert_json({"rate": 0, "price": 0.0})
parsed = json.loads(result)
assert parsed["rate"] == 0
assert parsed["price"] == 0.0

def test_unknown_type_falls_back_to_str(self):
class Custom:
def __str__(self):
return "custom_value"

result = convert_json({"obj": Custom()})
parsed = json.loads(result)
assert parsed["obj"] == "custom_value"


class TestSetStateRegression:
"""Regression tests for set_state scenarios from issues #2531, #2464, #2492.

These simulate what happens when set_state kwargs pass through
session.post(json=kwargs) with convert_json as the serializer
(the transparent POST path).
"""

def test_issue_2531_false_and_zero_attributes(self):
"""Reproduces the exact scenario from issue #2531."""
kwargs = {
"state": 1,
"attributes": {
"rate": 0,
"friendly_name": "Test Entity",
"unit_of_measurement": "GBP/kWh",
"plunge": False,
"plunge_start": False,
},
}
result = json.loads(convert_json(kwargs))
assert result["state"] == 1
assert result["attributes"]["rate"] == 0
assert result["attributes"]["plunge"] is False
assert result["attributes"]["plunge_start"] is False
assert result["attributes"]["friendly_name"] == "Test Entity"

def test_issue_2492_zero_float_in_nested_dict(self):
"""Reproduces the scenario from issue #2492 where 0.0 prices vanished."""
kwargs = {
"state": "0.08",
"attributes": {
"prices": {
"2025-11-29T00:00:00+02:00": {"price": 0.0, "intervals": 4},
"2025-11-29T11:00:00+02:00": {"price": 0.08, "intervals": 4},
}
},
}
result = json.loads(convert_json(kwargs))
prices = result["attributes"]["prices"]
assert prices["2025-11-29T00:00:00+02:00"]["price"] == 0.0
assert prices["2025-11-29T11:00:00+02:00"]["price"] == 0.08

def test_none_attribute_preserved_as_null(self):
"""None values in attributes should become JSON null, not be dropped."""
kwargs = {
"state": "on",
"attributes": {
"optional_field": None,
"name": "test",
},
}
result = json.loads(convert_json(kwargs))
assert "optional_field" in result["attributes"]
assert result["attributes"]["optional_field"] is None

def test_state_zero_preserved(self):
"""state=0 must not be dropped."""
kwargs = {"state": 0, "attributes": {"icon": "mdi:radiator"}}
result = json.loads(convert_json(kwargs))
assert result["state"] == 0
Loading