Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 73 additions & 63 deletions machine_learning/apriori_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
Examples: https://www.kaggle.com/code/earthian/apriori-association-rules-mining
"""

from collections import Counter
from collections import defaultdict
from itertools import combinations


Expand All @@ -25,78 +25,88 @@ def load_data() -> list[list[str]]:
return [["milk"], ["milk", "butter"], ["milk", "bread"], ["milk", "bread", "chips"]]


def prune(itemset: list, candidates: list, length: int) -> list:
"""
Prune candidate itemsets that are not frequent.
The goal of pruning is to filter out candidate itemsets that are not frequent. This
is done by checking if all the (k-1) subsets of a candidate itemset are present in
the frequent itemsets of the previous iteration (valid subsequences of the frequent
itemsets from the previous iteration).

Prunes candidate itemsets that are not frequent.

>>> itemset = ['X', 'Y', 'Z']
>>> candidates = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
>>> prune(itemset, candidates, 2)
[['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]

>>> itemset = ['1', '2', '3', '4']
>>> candidates = ['1', '2', '4']
>>> prune(itemset, candidates, 3)
[]
# ---------- Helpers ----------


def get_support(itemset: frozenset, transactions: list[set]) -> int:
    """Count how many transactions contain every item of ``itemset``."""
    hits = 0
    for basket in transactions:
        # `<=` is the subset operator: True when itemset ⊆ basket.
        if itemset <= basket:
            hits += 1
    return hits


def generate_candidates(prev_frequent: set[frozenset], k: int) -> set[frozenset]:
    """
    Generate candidate itemsets of size ``k`` from frequent itemsets of
    size ``k - 1``.

    Joins every unordered pair of previous frequent itemsets and keeps the
    union only when it has exactly ``k`` items (i.e. the pair overlaps in
    ``k - 2`` items).
    """
    return {
        union
        for a, b in combinations(prev_frequent, 2)
        if len(union := a | b) == k
    }


def has_infrequent_subset(candidate: frozenset, prev_frequent: set[frozenset]) -> bool:
    """
    Apriori pruning test: return True when some (k-1)-subset of
    ``candidate`` is not among the frequent (k-1)-itemsets, meaning the
    candidate cannot possibly be frequent itself.
    """
    subset_size = len(candidate) - 1
    return any(
        frozenset(subset) not in prev_frequent
        for subset in combinations(candidate, subset_size)
    )


# ---------- Main Apriori ----------


def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]:
    """
    Find all frequent itemsets and their support counts.

    Args:
        data: transactions, each given as a list of item names.
        min_support: minimum number of transactions an itemset must appear
            in to count as frequent.

    Returns:
        ``(sorted_items, support)`` pairs — smaller itemsets first and each
        size level sorted lexicographically, so the output order is
        deterministic regardless of hash randomisation.  Every row has the
        same shape (a sorted list of items), including the 1-itemsets.

    >>> data = [["A", "B", "C"], ["A", "B"], ["A", "C"], ["A", "D"], ["B", "C"]]
    >>> apriori(data, 2)
    [(['A'], 4), (['B'], 3), (['C'], 3), (['A', 'B'], 2), (['A', 'C'], 2), (['B', 'C'], 2)]
    >>> apriori(data, 6)
    []
    """
    transactions = [set(t) for t in data]

    # 1. count single items (1-itemsets)
    item_counts: defaultdict[frozenset, int] = defaultdict(int)
    for t in transactions:
        for item in t:
            item_counts[frozenset([item])] += 1

    frequent = {
        itemset for itemset, count in item_counts.items() if count >= min_support
    }

    # Emit 1-itemsets in the same (sorted-list, count) shape as the larger
    # itemsets so every result row has a consistent format.
    all_frequents = sorted(
        (sorted(i), c) for i, c in item_counts.items() if c >= min_support
    )

    k = 2

    while frequent:
        # 2. generate candidates of size k from the frequent (k-1)-itemsets
        candidates = generate_candidates(frequent, k)

        # 3. Apriori prune: drop candidates with an infrequent (k-1)-subset
        candidates = {c for c in candidates if not has_infrequent_subset(c, frequent)}

        # 4. a single pass over the transactions counts every candidate
        candidate_counts: defaultdict[frozenset, int] = defaultdict(int)
        for t in transactions:
            for c in candidates:
                if c.issubset(t):
                    candidate_counts[c] += 1

        # 5. keep the frequent candidates; record them in deterministic order
        frequent = {c for c, count in candidate_counts.items() if count >= min_support}
        all_frequents.extend(
            sorted(
                (sorted(c), count)
                for c, count in candidate_counts.items()
                if count >= min_support
            )
        )

        k += 1

    return all_frequents


if __name__ == "__main__":
Expand Down