"""Agentic strategy and query generation using ``output_type``.
Produces a :class:`QueryPlan` with structured :class:`GeneratedQuery` items.
"""
import json
import os
import re
from textwrap import dedent
from typing import List, Literal, get_args

from shared.llm import get_agent_model
from shared.logging import get_logger

from .models import GeneratedQuery, PipelineTask, QueryPlan
from .utils import retry_async
logger = get_logger(__name__)

# Identifiers of the search backends a generated query can be routed to.
SourceLiteral = Literal[
    "arxiv",
    "scholar",
    "pubmed",
    "github",
]
def _get_strategy_agent():
    """Construct the "Query Strategist" agent on demand.

    ``agents.Agent`` is imported inside the function rather than at module
    level (presumably so the SDK is only loaded when the agent is actually
    used). A new agent is built on every call; its final output is parsed
    into a :class:`QueryPlan` via ``output_type``.
    """
    from agents import Agent

    strategist_instructions = dedent(
        """
        You turn a user task into a compact set of search queries. For EACH query,
        you must also choose the most relevant source among: arXiv, Google Scholar,
        PubMed, GitHub.
        - Prefer concise keyword-style queries
        - Avoid redundancy between queries
        - Provide a short rationale per query
        - If source=arXiv, boolean-style with AND/OR/NOT is welcome; optional category constraints may apply
        - If source=PubMed, prefer biomedical terms and common synonyms
        - If source=GitHub, qualifiers like language:Python, stars:>100 are welcome
        - Keep the set small and high-precision
        - Output JSON matching the provided schema, including the "source" field per query
        """
    )
    return Agent(
        name="Query Strategist",
        model=get_agent_model(),
        instructions=strategist_instructions,
        output_type=QueryPlan,
    )
def _keyword_pattern(*keywords: str) -> "re.Pattern[str]":
    """Compile a case-insensitive pattern matching any of *keywords* as whole words.

    Whole-word matching keeps e.g. ``"gene"`` from firing on ``"generative"``
    and ``"code"`` from firing on ``"encoder"``. A keyword that ends in
    punctuation (the GitHub qualifier ``"stars:"``) is only anchored on the
    left, so ``"stars:>100"`` still matches.
    """
    alternatives = []
    for keyword in keywords:
        fragment = re.escape(keyword)
        if keyword[-1].isalnum():
            # Require the keyword to end at a word boundary.
            fragment += r"(?!\w)"
        alternatives.append(fragment)
    return re.compile(r"(?<!\w)(?:" + "|".join(alternatives) + ")", re.IGNORECASE)


# Ordered (source, pattern) rules used to route a query text to a search
# backend by keyword. The first matching rule wins; unmatched text falls
# through to arXiv in ``_infer_source``.
_SOURCE_RULES = (
    (
        "pubmed",
        _keyword_pattern(
            "clinical",
            "clinically",
            "biomedical",
            "gene",
            "genes",
            "genetic",
            "genetics",
            "protein",
            "proteins",
            "cancer",
            "cancers",
            "pubmed",
        ),
    ),
    (
        "github",
        _keyword_pattern(
            "github",
            "code",
            "codebase",
            "implementation",
            "implementations",
            "repo",
            "repos",
            "repository",
            "repositories",
            "stars:",
        ),
    ),
    (
        "scholar",
        _keyword_pattern(
            "survey",
            "surveys",
            "review",
            "reviews",
            "meta-analysis",
            "literature",
        ),
    ),
)


def _infer_source(text: str) -> SourceLiteral:
    """Pick a search backend for *text* from keyword rules.

    Biomedical terms route to ``"pubmed"``, code/repository terms to
    ``"github"``, and survey/review terms to ``"scholar"``, checked in that
    order. Anything else defaults to ``"arxiv"``.
    """
    for source, pattern in _SOURCE_RULES:
        if pattern.search(text):
            return source
    return "arxiv"


def _heuristic_plan(task: PipelineTask) -> QueryPlan:
    """Build a deterministic fallback plan from ``task.query`` alone.

    Produces four variants of the task text (direct match, surveys/reviews,
    practical artifacts, excluding theory-only work), truncated to
    ``task.max_queries``.
    """
    base: str = task.query.strip()
    # (query text, text used to infer its source, rationale). The appended
    # " survey" / " code" words steer inference toward scholar / github
    # unless a higher-priority rule (e.g. pubmed) already matches the base.
    variants = [
        (base, base, "Direct match to task"),
        (f"{base} AND (survey OR review)", base + " survey", "Surveys and reviews"),
        (
            f"{base} AND (benchmark OR dataset OR code)",
            base + " code",
            "Practical artifacts",
        ),
        (f"{base} NOT theory-only", base, "Exclude purely theoretical work"),
    ]
    queries: List[GeneratedQuery] = [
        GeneratedQuery(query_text=text, source=_infer_source(hint), rationale=why)
        for text, hint, why in variants
    ]
    fallback = QueryPlan(notes=None, queries=queries[: task.max_queries])
    logger.info(f"Heuristic produced {len(fallback.queries)} queries")
    return fallback


async def generate_query_plan(task: PipelineTask) -> QueryPlan:
    """Invoke the strategy agent and return a structured query plan.

    The deterministic keyword heuristic is used instead in three cases:
    the agent is disabled via ``PIPELINE_USE_AGENTS_STRATEGY``, the agent
    call raises, or the agent returns no queries. Any agent query without
    a ``source`` gets one inferred from its text.

    :param task: The pipeline task describing user intent and constraints.
    :returns: A :class:`QueryPlan` with up to ``task.max_queries`` queries.
    """
    # Compact JSON prompt, including optional user-suggested queries.
    payload = {
        "task": task.query,
        "categories": task.categories or [],
        "max_queries": task.max_queries,
        "suggested_queries": task.queries or [],
        # Derived from the Literal so the prompt cannot drift from the schema.
        "allowed_sources": list(get_args(SourceLiteral)),
    }
    prompt = json.dumps(payload)
    logger.debug(
        f"Generating query plan (max={task.max_queries}, categories={task.categories})"
    )

    # Enabled by default; only "1"/"true"/"yes" (case-insensitive) keep it on.
    use_agents = os.getenv("PIPELINE_USE_AGENTS_STRATEGY", "1").lower() in {
        "1",
        "true",
        "yes",
    }
    if not use_agents:
        logger.info("Strategy agent disabled via env; using heuristic queries")
        # Previously this raised outside the try/except, so the promised
        # fallback never happened and the caller received an exception.
        return _heuristic_plan(task)

    try:
        logger.info("Making a call to the strategy agent...")
        from agents import Runner

        result = await retry_async(lambda: Runner.run(_get_strategy_agent(), prompt))
        plan_obj: QueryPlan = result.final_output
        queries = getattr(plan_obj, "queries", None)
        logger.info(f"Strategy agent produced {len(queries) if queries else 0} queries")
        if not queries:
            raise ValueError("Empty plan")
        # Ensure every query carries a source; fill gaps with the heuristic.
        for q in queries:
            if not getattr(q, "source", None):
                q.source = _infer_source(q.query_text or "")
        plan_obj.queries = queries[: task.max_queries]
        logger.debug(
            "Queries: " + "; ".join(q.query_text for q in plan_obj.queries[:5])
        )
        return plan_obj
    except Exception as error:
        # Deliberate best-effort boundary: any agent/SDK/validation failure
        # degrades to the deterministic plan rather than failing the pipeline.
        logger.warning(f"Strategy agent failed, using heuristic fallback: {error}")
        return _heuristic_plan(task)