Refa: PARALLEL_DEVICES is a static parameter. (#6168)

### What problem does this PR solve?


### Type of change

- [x] Refactoring
This commit is contained in:
Kevin Hu
2025-03-17 16:49:54 +08:00
committed by GitHub
parent 45fe02c8b3
commit 3a99c2b5f4
6 changed files with 29 additions and 28 deletions

View File

@ -128,8 +128,8 @@ class Docx(DocxParser):
class Pdf(PdfParser):
def __init__(self, parallel_devices = None):
super().__init__(parallel_devices)
def __init__(self):
super().__init__()
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
@ -197,7 +197,7 @@ class Markdown(MarkdownParser):
def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, parallel_devices=None, **kwargs):
lang="Chinese", callback=None, **kwargs):
"""
Supported file formats are docx, pdf, excel, txt.
This method apply the naive ways to chunk files.
@ -237,7 +237,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
return res
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf(parallel_devices)
pdf_parser = Pdf()
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,

View File

@ -39,6 +39,13 @@ SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
PAGERANK_FLD = "pagerank_fea"
TAG_FLD = "tag_feas"
PARALLEL_DEVICES = None
try:
import torch.cuda
PARALLEL_DEVICES = torch.cuda.device_count()
logging.info(f"found {PARALLEL_DEVICES} gpus")
except Exception:
logging.info("can't import package 'torch'")
def print_rag_settings():
logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")

View File

@ -100,13 +100,6 @@ MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDER
task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
PARALLEL_DEVICES = None
try:
import torch.cuda
PARALLEL_DEVICES = torch.cuda.device_count()
logging.info(f"found {PARALLEL_DEVICES} gpus")
except Exception:
logging.info("can't import package 'torch'")
# SIGUSR1 handler: start tracemalloc and take snapshot
def start_tracemalloc_and_snapshot(signum, frame):
@ -249,7 +242,7 @@ async def build_chunks(task, progress_callback):
try:
async with chunk_limiter:
cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
to_page=task["to_page"], lang=task["language"], parallel_devices = PARALLEL_DEVICES, callback=progress_callback,
to_page=task["to_page"], lang=task["language"], callback=progress_callback,
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
except TaskCanceledException: