Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions USER_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ $ source chakra_env/bin/activate
With the virtual environment activated, install the Chakra package using pip.

```bash
# Install package from PyPi
$ pip install mlc-chakra

# Install package from source
$ pip install .

Expand All @@ -29,13 +32,11 @@ $ pip install https://github.com/mlcommons/chakra/archive/ae7c671db702eb1384015b
Installing PARAM is necessary for Chakra to function properly as it imports essential components from it.

```bash
$ git clone git@github.com:facebookresearch/param.git
$ cd param/et_replay
$ git checkout 7b19f586dd8b267333114992833a0d7e0d601630
$ pip install .
$ pip install "git+https://github.com/facebookresearch/param.git@7b19f586dd8b267333114992833a0d7e0d601630#subdirectory=et_replay"
```

### Step 4: Install Holistic Trace Analysis
Skip this step if you installed the PyPi package.
Installing Holistic Trace Analysis is necessary for Trace link.

```bash
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ requires = [ "setuptools>=61",
build-backend = "setuptools.build_meta"

[project]
name = "chakra"
name = "mlc-chakra"
requires-python = ">=3.7"
version = "1.0.0"
readme = "README.md"
Expand All @@ -19,7 +19,7 @@ dependencies = [
"graphviz",
"networkx",
"pydot",
"HolisticTraceAnalysis @ git+https://github.com/facebookresearch/HolisticTraceAnalysis.git@d731cc2e2249976c97129d409a83bd53d93051f6"
"HolisticTraceAnalysis <= 0.5.0",
]

[project.urls]
Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@ pyright==1.1.359
pytest==8.1.1
ruff==0.3.7
vulture==2.11
build
twine
9 changes: 8 additions & 1 deletion src/trace_link/chakra_device_trace_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Tuple

from et_replay.utils import read_dictionary_from_json_file
try:
from et_replay.utils import read_dictionary_from_json_file
except ImportError as e:
if "et_replay" in str(e):
from .et_replay_import_error import get_et_replay_install_error_msg

raise ImportError(get_et_replay_install_error_msg()) from None
raise

from .kineto_operator import KinetoOperator

Expand Down
11 changes: 9 additions & 2 deletions src/trace_link/chakra_host_trace_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@
import sys
from typing import List

from et_replay.execution_trace import Node as PyTorchOperator
from et_replay.utils import load_execution_trace_file
try:
from et_replay.execution_trace import Node as PyTorchOperator
from et_replay.utils import load_execution_trace_file
except ImportError as e:
if "et_replay" in str(e):
from .et_replay_import_error import get_et_replay_install_error_msg

raise ImportError(get_et_replay_install_error_msg()) from None
raise

# Increase the recursion limit for deep Chakra host execution traces.
sys.setrecursionlimit(10**6)
Expand Down
15 changes: 15 additions & 0 deletions src/trace_link/et_replay_import_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Utility for handling et_replay import errors."""


def get_et_replay_install_error_msg() -> str:
"""
Get the error message for missing et_replay installation.

Returns
str: Error message with installation instructions.
"""
return (
"Failed to import et_replay. et_replay is required but not packaged because it is not available as a PyPi package.\n\n" # noqa: E501.
"Please install it using:\n"
' pip install "git+https://github.com/facebookresearch/param.git@7b19f586dd8b267333114992833a0d7e0d601630#subdirectory=et_replay"\n\n' # noqa: E501.
)
9 changes: 8 additions & 1 deletion src/trace_link/kineto_operator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from typing import Any, Dict, List, Optional

from et_replay.execution_trace import Node as PyTorchOperator
try:
from et_replay.execution_trace import Node as PyTorchOperator
except ImportError as e:
if "et_replay" in str(e):
from .et_replay_import_error import get_et_replay_install_error_msg

raise ImportError(get_et_replay_install_error_msg()) from None
raise


class KinetoOperator:
Expand Down
18 changes: 13 additions & 5 deletions src/trace_link/trace_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,19 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional, Tuple

from et_replay.execution_trace import (
EXECUTION_TRACE_PROCESS_ANNOTATION,
EXECUTION_TRACE_THREAD_ANNOTATION,
)
from et_replay.execution_trace import Node as PyTorchOperator
try:
from et_replay.execution_trace import (
EXECUTION_TRACE_PROCESS_ANNOTATION,
EXECUTION_TRACE_THREAD_ANNOTATION,
)
from et_replay.execution_trace import Node as PyTorchOperator
except ImportError as e:
if "et_replay" in str(e):
from .et_replay_import_error import get_et_replay_install_error_msg

raise ImportError(get_et_replay_install_error_msg()) from None
raise

from hta.analyzers.critical_path_analysis import CPEdgeType
from hta.trace_analysis import TraceAnalysis

Expand Down
10 changes: 9 additions & 1 deletion tests/trace_link/test_chakra_host_trace_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@

import pytest
from chakra.src.trace_link.chakra_host_trace_loader import ChakraHostTraceLoader
from et_replay.execution_trace import Node as PyTorchOperator

try:
from et_replay.execution_trace import Node as PyTorchOperator
except ImportError as e:
if "et_replay" in str(e):
from chakra.src.trace_link.et_replay_import_error import get_et_replay_install_error_msg

raise ImportError(get_et_replay_install_error_msg()) from None
raise


@pytest.fixture
Expand Down
19 changes: 14 additions & 5 deletions tests/trace_link/test_trace_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,20 @@
from chakra.src.trace_link.kineto_operator import KinetoOperator
from chakra.src.trace_link.trace_linker import TraceLinker
from chakra.src.trace_link.unique_id_assigner import UniqueIdAssigner
from et_replay.execution_trace import (
EXECUTION_TRACE_PROCESS_ANNOTATION,
EXECUTION_TRACE_THREAD_ANNOTATION,
)
from et_replay.execution_trace import Node as PyTorchOperator

try:
from et_replay.execution_trace import (
EXECUTION_TRACE_PROCESS_ANNOTATION,
EXECUTION_TRACE_THREAD_ANNOTATION,
)
from et_replay.execution_trace import Node as PyTorchOperator
except ImportError as e:
if "et_replay" in str(e):
from chakra.src.trace_link.et_replay_import_error import get_et_replay_install_error_msg

raise ImportError(get_et_replay_install_error_msg()) from None
raise

from hta.analyzers.critical_path_analysis import CPEdgeType, CPGraph


Expand Down
Loading