diff --git a/configs/mixed_landmark_0814_no_extend_qsa.json b/configs/mixed_landmark_0814_no_extend_qsa.json index 1fc11ac2..c69ea1e5 100644 --- a/configs/mixed_landmark_0814_no_extend_qsa.json +++ b/configs/mixed_landmark_0814_no_extend_qsa.json @@ -7,24 +7,74 @@ "__delta_attention_args": "window_0-diff_1-w_16-dense_decode-smooth", "using_extend": false, "dense_layers": [0, 1, 2, 47, 46, 45], - "mask_refresh_interval": [96], + "mask_refresh_interval": [96, 32, 16], "layers": [ { "sliding_window_size": 1024, "sliding_window_size_for_masking_step": [1024, 1024, 1024], - "second_stage_k": 1024, + "second_stage_k": 2048, "sink_token_size": 1024, "sa_extend_backend": "self_extend", - "stages": [ { } ] + "stages": [ + { + "stage_block_size_q":128, + "stage_block_stride_q":4, + "stage_chunk_size":256, + "stage_k":null, + "stage_stride":1, + "using_landmark":false + }, + { + "stage_block_size_q":64, + "stage_block_stride_q":1, + "stage_chunk_size":32, + "stage_k":65536, + "stage_stride":1, + "using_landmark":false + }, + { + "stage_block_size_q":64, + "stage_block_stride_q":1, + "stage_chunk_size":8, + "stage_k":8192, + "stage_stride":1, + "using_landmark":false + } + ] }, { "sliding_window_size": 1024, "sliding_window_size_for_masking_step": [1024, 1024, 1024], - "second_stage_k": 1024, + "second_stage_k": 2048, "sink_token_size": 1024, "sa_extend_backend": "self_extend", "scan_extend_backend": "none", - "stages": [ { } ] + "stages": [ + { + "stage_block_size_q":128, + "stage_block_stride_q":4, + "stage_chunk_size":256, + "stage_k":null, + "stage_stride":1, + "using_landmark":false + }, + { + "stage_block_size_q":64, + "stage_block_stride_q":1, + "stage_chunk_size":32, + "stage_k":65536, + "stage_stride":1, + "using_landmark":false + }, + { + "stage_block_size_q":64, + "stage_block_stride_q":1, + "stage_chunk_size":8, + "stage_k":8192, + "stage_stride":1, + "using_landmark":false + } + ] } ], "prefill_layers": [ diff --git a/scripts/bench_latency_paged_attn.py b/scripts/bench_latency_paged_attn.py new file mode 100644 index 00000000..eda3e2a9 --- /dev/null +++ b/scripts/bench_latency_paged_attn.py @@ -0,0 +1,201 @@ +""" +python scripts/benchmark_latency_paged_attn.py +""" + +import os +import json +import traceback +import pandas as pd +import torch +from transformers import AutoConfig +import triton +from hip_attn.v1_2.paged_hip import forward_paged_hip, HiPAttentionConfig + +def forward_seq_len( + dtype: torch.dtype, + seq_len: int, + q_head: int, + kv_head: int, + head_dim: int, + hip_config: HiPAttentionConfig, + batch_size: int = 1, +): + device = torch.device("cuda:0") + + query = torch.rand( + (batch_size * seq_len, q_head, head_dim), + dtype=torch.bfloat16, + device=device + ) + k_cache = torch.rand( + # NOTE: + 1 is special behavior on SGlang. I am not sure about it is exists in vLLM too. + ((seq_len + 1) * batch_size, kv_head, head_dim), + dtype=torch.bfloat16, + device=device, + ).to(dtype) + v_cache = k_cache.clone() + positions = torch.arange(0, batch_size * seq_len, dtype=torch.long, device=device) + positions = positions % seq_len + seq_lens = torch.zeros((batch_size,), dtype=torch.long, device=device) + seq_lens[:] = seq_len + block_table = torch.arange(0, batch_size * seq_len, dtype=torch.long, device=device) + block_table = block_table.view(batch_size, seq_len) + layer_id = 10 + logit_cap = None + orig_context_length = seq_len + max_context_length = seq_len + is_kv_cache_offload_enable = False + rope_range = (0, head_dim) + extend_prefix_lens_cpu = [0,] * batch_size + extend_seq_lens_cpu = [seq_len,] * batch_size + + torch.cuda.synchronize() + + start = torch.cuda.Event(True) + end = torch.cuda.Event(True) + + start.record() + + forward_paged_hip( + query=query, + sm_scale=1 / (head_dim ** 0.5), + batch_size=batch_size, + k_cache=k_cache, + v_cache=v_cache, + offload_cache=None, + positions=positions, + seq_lens=seq_lens, + req_to_tokens=None, + req_pool_indices=None, + block_table=block_table, + rope_cos=None, + rope_sin=None, + layer_id=layer_id, + logit_cap=logit_cap, + orig_context_len=orig_context_length, + max_context_len=max_context_length, + hip_config=hip_config, + is_kv_cache_offload_enabled=is_kv_cache_offload_enable, + rope_range=rope_range, + extend_prefix_lens_cpu=extend_prefix_lens_cpu, + extend_seq_lens_cpu=extend_seq_lens_cpu, + is_decode=False, + ) + + end.record() + end.synchronize() + return start.elapsed_time(end) + +def try_set_environ(name: str, value): + if name in os.environ: + return + os.environ[name] = value + +def evaluate_autotune( + sa_block_size: int, + bsa_block_k: int, + hip_config: HiPAttentionConfig, +): + model_name = "Qwen/Qwen3-235B-A22B-Instruct-2507" + + try_set_environ("BSA_K", "32") + try_set_environ("BSA_EXACT_K", "32") + try_set_environ("BSA_BLOCK_K", str(bsa_block_k)) + try_set_environ("HIP_DEBUG_DELTA_QSA", "1") + try_set_environ("HIP_DEBUG_RECOMPUTE_SPLIT", "0") + try_set_environ("TRITON_PRINT_AUTOTUNING", "1") + try_set_environ("SA_BLOCK_SIZE", str(sa_block_size)) + try_set_environ("SA_DECODE_BLOCK_SIZE", "128") + try_set_environ("HIP_DISABLE_AUTOTUNE", "0") + + n_warmup = 3 + n_measure = 100 + n_tp = 8 + + config = AutoConfig.from_pretrained(model_name) + q_head = config.num_attention_heads + q_head = triton.cdiv(q_head, n_tp) + kv_head = config.num_key_value_heads + kv_head = triton.cdiv(kv_head, n_tp) + head_dim = config.hidden_size // config.num_attention_heads + + seq_lens = [32, 64, 128, 256, 384, 512, 768, 1024] + dtypes = [torch.bfloat16, torch.float8_e5m2] + + data = [] + + for seq_len in seq_lens: + for dtype in dtypes: + for _ in range(n_warmup): + try: + forward_seq_len( + dtype, + seq_len * 1024, + q_head, + kv_head, + head_dim, + hip_config=hip_config, + ) + except Exception: + traceback.print_exc() + + latencies = [] + for _ in range(n_measure): + try: + latency = forward_seq_len( + dtype, + seq_len * 1024, + q_head, + kv_head, + head_dim, + hip_config=hip_config, + ) + exception = "" + except Exception: + latency = float("nan") + exception = traceback.format_exc() + latencies.append(latency) + latency = sum(latencies) / len(latencies) + + data_point = { + "dtype": str(dtype), + "seq_len": seq_len, + "bsa_block_k": bsa_block_k, + "sa_block_size": sa_block_size, + "model": model_name, + "latency": latency, + "exception": exception, + } + print(data_point, flush=True) + data.append(data_point) + + return data + +def main(): + hip_config = HiPAttentionConfig( + json_or_path="./configs/mixed_landmark_0814_no_extend_qsa.json", + json_override='{"__seq_thresh_fa3": 0}' + ) + + bsa_block_ks = [64, 32] + sa_block_sizes = [256, 128, 64] + + data = [] + + for bsa_block_k in bsa_block_ks: + for sa_block_size in sa_block_sizes: + data.extend(evaluate_autotune( + bsa_block_k=bsa_block_k, + sa_block_size=sa_block_size, + hip_config=hip_config, + )) + + os.makedirs("saves/bench_latency_paged_attn", exist_ok=True) + with open("saves/bench_latency_paged_attn/measures.json", "w") as f: + json.dump(data, f) + + df = pd.DataFrame(data) + df.to_csv("saves/bench_latency_paged_attn/measures.csv") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/hip_attn/utils/sglang_watchdog.py b/src/hip_attn/utils/sglang_watchdog.py new file mode 100644 index 00000000..0baca976 --- /dev/null +++ b/src/hip_attn/utils/sglang_watchdog.py @@ -0,0 +1,148 @@ +import argparse +import datetime +import sys +import os +import subprocess +import threading +import time +import traceback +import requests + +def log(*args): + comment = " ".join([str(a) for a in args]) + timestamp = "{:%Y-%m-%d %H:%M:%S}".format(datetime.datetime.now()) + print(f"\033[91m[{timestamp} sglang_watchdog] {comment}\033[0m", flush=True) + +class Watchdog: + def __init__( + self, + ): + self.timeout_bootup = 600 + self.timeout_tick = 60 + self.sleep_step = 1 + self.proc: subprocess.Popen = None + self.argv: list[str] = None + self.running: bool = True + + def start_subprocess(self): + args = [ + "python", + "-m", + "sglang.launch_server", + *self.argv + ] + flatten_args = " ".join(args) + log(f"Start subprocess using following command: {flatten_args}") + self.proc = subprocess.Popen(args) + log(f"Start subprocess communication.") + return_code = self.proc.wait() + log(f"Return code is {return_code}") + + def kill_subprocess(self): + log(f"Start kill subprocess") + if self.proc is not None: + self.proc.kill() + self.proc = None + subprocess.call(["pkill", "sglang"]) + log(f"Finish kill subprocess") + + def wait_for_health(self, timeout: int): + response = requests.get(self.health_endpoint, timeout=timeout) + response.raise_for_status() + + def main_watchdog(self): + while True: + try: + t_boot = time.time() + booted = False + while self.proc is None: + log("Watchdog is waiting for process started...") + time.sleep(self.sleep_step) + while ( + (time.time() - t_boot) < self.timeout_bootup + and self.proc.returncode is None + and not booted + ): + try: + self.wait_for_health(timeout=self.timeout_bootup) + log("Server booted successfully.") + booted = True + except (TimeoutError, requests.HTTPError, requests.ConnectionError): + # NOTE: may process is not started yet + pass + time.sleep(self.sleep_step) + + if not booted: raise TimeoutError() + + while True: + log("Try watch dog.") + self.wait_for_health(timeout=self.timeout_tick) + log("Done watch dog successfully.") + time.sleep(self.timeout_tick) + + except (TimeoutError, requests.HTTPError): + self.kill_subprocess() + except Exception as ex: + trace = traceback.format_exc() + log(f"Traceback:\n{trace}") + log(f"Unexpected error on watchdog thread: {ex}") + self.kill_subprocess() + + time.sleep(self.sleep_step) + + def main_starter(self): + while True: + self.start_subprocess() + time.sleep(self.sleep_step) + + def start(self): + try: + if "--" in sys.argv: + my_args = sys.argv[1:sys.argv.index("--")] + argv = sys.argv[sys.argv.index("--") + 1:] + else: + my_args = [] + argv = sys.argv[1:] + + parser = argparse.ArgumentParser() + parser.add_argument("--timeout-bootup", default=self.timeout_bootup, type=int) + parser.add_argument("--timeout", default=self.timeout_tick, type=int) + parser.add_argument("--sleep-step", default=self.sleep_step, type=int) + + args = parser.parse_args(my_args) + self.timeout_bootup = args.timeout_bootup + self.timeout_tick = args.timeout + self.sleep_step = args.sleep_step + + assert "--host" in argv + assert "--port" in argv + self.host = argv[argv.index("--host") + 1] + self.port = argv[argv.index("--port") + 1] + self.health_endpoint = f"http://{self.host}:{self.port}/health" + log(f"Watching: {self.health_endpoint}") + + self.argv = argv + + self.thread_watchdog = threading.Thread( + target=self.main_watchdog, + daemon=True + ) + self.thread_starter = threading.Thread( + target=self.main_starter, + daemon=True + ) + + self.thread_starter.start() + time.sleep(self.sleep_step) + self.thread_watchdog.start() + + self.thread_watchdog.join() + self.thread_starter.join() + + self.running = False + except KeyboardInterrupt: + self.kill_subprocess() + +if __name__ == '__main__': + dog = Watchdog() + dog.start() \ No newline at end of file diff --git a/src/hip_attn/v1_2/attention_extend.py b/src/hip_attn/v1_2/attention_extend.py index 36ddf535..449c2c02 100644 --- a/src/hip_attn/v1_2/attention_extend.py +++ b/src/hip_attn/v1_2/attention_extend.py @@ -16,7 +16,12 @@ from hip_attn.utils.rope import adjust_rope from hip_attn.v1_2.attention_decode_bsa import decode_block_sparse_attention from hip_attn.v1_2.attention_extend_bsa import block_sparse_attention -from hip_attn.v1_2.attention_extend_bsa_tilelang import block_sparse_attention_tilelang + +try: + from hip_attn.v1_2.attention_extend_bsa_tilelang import block_sparse_attention_tilelang +except (ImportError, OSError): + block_sparse_attention_tilelang = None + from hip_attn.v1_2.attention_metadata import ( EnsembleScoreStage, EvalScoreStage, diff --git a/src/hip_attn/v1_2/paged_hip.py b/src/hip_attn/v1_2/paged_hip.py index 58fca4b9..7df14fa3 100644 --- a/src/hip_attn/v1_2/paged_hip.py +++ b/src/hip_attn/v1_2/paged_hip.py @@ -2,17 +2,85 @@ import math import os import warnings -from typing import Any, List, Optional +from typing import Any, List, Optional, Union import cv2 import numba import numpy as np import torch import triton -from flash_attn import flash_attn_func from matplotlib import pyplot as plt -from sgl_kernel.flash_attn import flash_attn_varlen_func as __flash_attn_varlen_func -from sgl_kernel.flash_attn import flash_attn_with_kvcache + +try: + from flash_attn import flash_attn_func +except ImportError: + flash_attn_func = None + +try: + from sgl_kernel.flash_attn import flash_attn_varlen_func as __flash_attn_varlen_func + from sgl_kernel.flash_attn import flash_attn_with_kvcache + IS_AMD = False +except ImportError: + # FIXME: better AMD detection algorithm + IS_AMD = True + + from flash_attn import flash_attn_varlen_func as __flash_attn_varlen_func + from flash_attn import flash_attn_with_kvcache as __flash_attn_with_kvcache + + def flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=None, + v=None, + qv=None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + rotary_seqlens: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + rotary_interleaved=True, + scheduler_metadata=None, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication + return_softmax_lse=False, + sinks=None, + ver=3, + ): + return __flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=k, + v=v, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + block_table=page_table, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, # -1 means infinite context window + softcap=softcap, # 0.0 means deactivated + rotary_interleaved=rotary_interleaved, + alibi_slopes=None, + num_splits=num_splits, + return_softmax_lse=return_softmax_lse, + ) from hip_attn.v1_2.hip_config import HiPAttentionConfig from hip_attn.v1_2.utils import capture @@ -130,9 +198,14 @@ def flash_attn_varlen_func( split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, + get_tp_group, ) - SGLANG_DIST_ACTIVATED = True + try: + get_tp_group() + SGLANG_DIST_ACTIVATED = True + except AssertionError: + SGLANG_DIST_ACTIVATED = False except ImportError as ex: SGLANG_DIST_ACTIVATED = False @@ -365,8 +438,11 @@ def forward_paged_hip( positions=positions[start_len : start_len + seq_len], seq_lens=seq_lens[idx_batch : idx_batch + 1], req_to_tokens=req_to_tokens, - req_pool_indices=req_pool_indices[idx_batch : idx_batch + 1], - block_table=None, + req_pool_indices=( + req_pool_indices[idx_batch : idx_batch + 1] + if req_pool_indices is not None else None + ), + block_table=block_table, rope_cos=rope_cos, rope_sin=rope_sin, rope_range=rope_range, diff --git a/src/hip_attn/v1_2/query_sparse_attention.py b/src/hip_attn/v1_2/query_sparse_attention.py index d001e213..c862ba1f 100644 --- a/src/hip_attn/v1_2/query_sparse_attention.py +++ b/src/hip_attn/v1_2/query_sparse_attention.py @@ -1907,7 +1907,7 @@ def forward( assert rope_cos.ndim == 2 assert extend_backend in ["self_extend", "nope"] - if rope_sin is not None: + if (rope_sin is not None) and (extend_backend in ["self_extend"]): HEAD_DIM_K_ROPE = rope_sin.shape[-1] HEAD_DIM_K_NOPE = HEAD_DIM_K - HEAD_DIM_K_ROPE else: