From c1230cd187cf292f67ca905c9928680852420cda Mon Sep 17 00:00:00 2001 From: Jinsun Yoo Date: Wed, 20 May 2026 17:36:46 +0000 Subject: [PATCH] Changes needed for pypi packaging --- USER_GUIDE.md | 9 +++++---- pyproject.toml | 4 ++-- requirements-dev.txt | 2 ++ src/trace_link/chakra_device_trace_loader.py | 9 ++++++++- src/trace_link/chakra_host_trace_loader.py | 11 +++++++++-- src/trace_link/et_replay_import_error.py | 15 +++++++++++++++ src/trace_link/kineto_operator.py | 9 ++++++++- src/trace_link/trace_linker.py | 18 +++++++++++++----- .../test_chakra_host_trace_loader.py | 10 +++++++++- tests/trace_link/test_trace_linker.py | 19 ++++++++++++++----- 10 files changed, 85 insertions(+), 21 deletions(-) create mode 100644 src/trace_link/et_replay_import_error.py diff --git a/USER_GUIDE.md b/USER_GUIDE.md index 164aaeb6..9246fd1e 100644 --- a/USER_GUIDE.md +++ b/USER_GUIDE.md @@ -15,6 +15,9 @@ $ source chakra_env/bin/activate With the virtual environment activated, install the Chakra package using pip. ```bash +# Install package from PyPi +$ pip install mlc-chakra + # Install package from source $ pip install . @@ -29,13 +32,11 @@ $ pip install https://github.com/mlcommons/chakra/archive/ae7c671db702eb1384015b Installing PARAM is necessary for Chakra to function properly as it imports essential components from it. ```bash -$ git clone git@github.com:facebookresearch/param.git -$ cd param/et_replay -$ git checkout 7b19f586dd8b267333114992833a0d7e0d601630 -$ pip install . +$ pip install "git+https://github.com/facebookresearch/param.git@7b19f586dd8b267333114992833a0d7e0d601630#subdirectory=et_replay" ``` ### Step 4: Install Holistic Trace Analysis +Skip this step if you installed the PyPi package. Installing Holistic Trace Analysis is necessary for Trace link. ```bash diff --git a/pyproject.toml b/pyproject.toml index 99193c03..cab1c3c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ "setuptools>=61", build-backend = "setuptools.build_meta" [project] -name = "chakra" +name = "mlc-chakra" requires-python = ">=3.7" version = "1.0.0" readme = "README.md" @@ -19,7 +19,7 @@ dependencies = [ "graphviz", "networkx", "pydot", - "HolisticTraceAnalysis @ git+https://github.com/facebookresearch/HolisticTraceAnalysis.git@d731cc2e2249976c97129d409a83bd53d93051f6" + "HolisticTraceAnalysis <= 0.5.0", ] [project.urls] diff --git a/requirements-dev.txt b/requirements-dev.txt index 810997f1..f3212fc7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,3 +2,5 @@ pyright==1.1.359 pytest==8.1.1 ruff==0.3.7 vulture==2.11 +build +twine diff --git a/src/trace_link/chakra_device_trace_loader.py b/src/trace_link/chakra_device_trace_loader.py index 0737f132..3723c774 100644 --- a/src/trace_link/chakra_device_trace_loader.py +++ b/src/trace_link/chakra_device_trace_loader.py @@ -3,7 +3,14 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict, List, Tuple -from et_replay.utils import read_dictionary_from_json_file +try: + from et_replay.utils import read_dictionary_from_json_file +except ImportError as e: + if "et_replay" in str(e): + from .et_replay_import_error import get_et_replay_install_error_msg + + raise ImportError(get_et_replay_install_error_msg()) from None + raise from .kineto_operator import KinetoOperator diff --git a/src/trace_link/chakra_host_trace_loader.py b/src/trace_link/chakra_host_trace_loader.py index 8b2723b3..380a625b 100644 --- a/src/trace_link/chakra_host_trace_loader.py +++ b/src/trace_link/chakra_host_trace_loader.py @@ -2,8 +2,15 @@ import sys from typing import List -from et_replay.execution_trace import Node as PyTorchOperator -from et_replay.utils import load_execution_trace_file +try: + from et_replay.execution_trace import Node as PyTorchOperator + from et_replay.utils import load_execution_trace_file +except ImportError as e: + if "et_replay" in str(e): + from .et_replay_import_error import get_et_replay_install_error_msg + + raise ImportError(get_et_replay_install_error_msg()) from None + raise # Increase the recursion limit for deep Chakra host execution traces. sys.setrecursionlimit(10**6) diff --git a/src/trace_link/et_replay_import_error.py b/src/trace_link/et_replay_import_error.py new file mode 100644 index 00000000..5f9a1e8a --- /dev/null +++ b/src/trace_link/et_replay_import_error.py @@ -0,0 +1,15 @@ +"""Utility for handling et_replay import errors.""" + + +def get_et_replay_install_error_msg() -> str: + """ + Get the error message for missing et_replay installation. + + Returns + str: Error message with installation instructions. + """ + return ( + "Failed to import et_replay. et_replay is required but not packaged because it is not available as a PyPi package.\n\n" # noqa: E501. + "Please install it using:\n" + ' pip install "git+https://github.com/facebookresearch/param.git@7b19f586dd8b267333114992833a0d7e0d601630#subdirectory=et_replay"\n\n' # noqa: E501. + ) diff --git a/src/trace_link/kineto_operator.py b/src/trace_link/kineto_operator.py index f286fa6f..2293fe1e 100644 --- a/src/trace_link/kineto_operator.py +++ b/src/trace_link/kineto_operator.py @@ -1,6 +1,13 @@ from typing import Any, Dict, List, Optional -from et_replay.execution_trace import Node as PyTorchOperator +try: + from et_replay.execution_trace import Node as PyTorchOperator +except ImportError as e: + if "et_replay" in str(e): + from .et_replay_import_error import get_et_replay_install_error_msg + + raise ImportError(get_et_replay_install_error_msg()) from None + raise class KinetoOperator: diff --git a/src/trace_link/trace_linker.py b/src/trace_link/trace_linker.py index af247dd7..5b13c839 100644 --- a/src/trace_link/trace_linker.py +++ b/src/trace_link/trace_linker.py @@ -6,11 +6,19 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict, List, Optional, Tuple -from et_replay.execution_trace import ( - EXECUTION_TRACE_PROCESS_ANNOTATION, - EXECUTION_TRACE_THREAD_ANNOTATION, -) -from et_replay.execution_trace import Node as PyTorchOperator +try: + from et_replay.execution_trace import ( + EXECUTION_TRACE_PROCESS_ANNOTATION, + EXECUTION_TRACE_THREAD_ANNOTATION, + ) + from et_replay.execution_trace import Node as PyTorchOperator +except ImportError as e: + if "et_replay" in str(e): + from .et_replay_import_error import get_et_replay_install_error_msg + + raise ImportError(get_et_replay_install_error_msg()) from None + raise + from hta.analyzers.critical_path_analysis import CPEdgeType from hta.trace_analysis import TraceAnalysis diff --git a/tests/trace_link/test_chakra_host_trace_loader.py b/tests/trace_link/test_chakra_host_trace_loader.py index 19aeec8f..a47aa289 100644 --- a/tests/trace_link/test_chakra_host_trace_loader.py +++ b/tests/trace_link/test_chakra_host_trace_loader.py @@ -2,7 +2,15 @@ import pytest from chakra.src.trace_link.chakra_host_trace_loader import ChakraHostTraceLoader -from et_replay.execution_trace import Node as PyTorchOperator + +try: + from et_replay.execution_trace import Node as PyTorchOperator +except ImportError as e: + if "et_replay" in str(e): + from chakra.src.trace_link.et_replay_import_error import get_et_replay_install_error_msg + + raise ImportError(get_et_replay_install_error_msg()) from None + raise @pytest.fixture diff --git a/tests/trace_link/test_trace_linker.py b/tests/trace_link/test_trace_linker.py index 8430867e..bcf40e73 100644 --- a/tests/trace_link/test_trace_linker.py +++ b/tests/trace_link/test_trace_linker.py @@ -4,11 +4,20 @@ from chakra.src.trace_link.kineto_operator import KinetoOperator from chakra.src.trace_link.trace_linker import TraceLinker from chakra.src.trace_link.unique_id_assigner import UniqueIdAssigner -from et_replay.execution_trace import ( - EXECUTION_TRACE_PROCESS_ANNOTATION, - EXECUTION_TRACE_THREAD_ANNOTATION, -) -from et_replay.execution_trace import Node as PyTorchOperator + +try: + from et_replay.execution_trace import ( + EXECUTION_TRACE_PROCESS_ANNOTATION, + EXECUTION_TRACE_THREAD_ANNOTATION, + ) + from et_replay.execution_trace import Node as PyTorchOperator +except ImportError as e: + if "et_replay" in str(e): + from chakra.src.trace_link.et_replay_import_error import get_et_replay_install_error_msg + + raise ImportError(get_et_replay_install_error_msg()) from None + raise + from hta.analyzers.critical_path_analysis import CPEdgeType, CPGraph