Powered by:NEFU AB-IN
3213. 最小代价构造字符串
题意
给你一个字符串 target、一个字符串数组 words 以及一个整数数组 costs,这两个数组长度相同。
设想一个空字符串 s。
你可以执行以下操作任意次数(包括零次):
选择一个在范围 [0, words.length - 1] 的索引 i。
将 words[i] 追加到 s。
该操作的成本是 costs[i]。
返回使 s 等于 target 的 最小 成本。如果不可能,返回 -1。
思路
字典树/字符串哈希 + dp
- 使用 Trie 树存储单词和成本:
我们将所有的单词和对应的成本插入到一个 Trie 树中。Trie 树是一种多叉树,可以快速查找以某个前缀开头的所有单词。
这样我们就能在 Trie 树中快速查找到以 target 中某个位置开始的所有前缀单词及其成本。
- 动态规划(Dynamic Programming):
使用一个动态规划数组 dp,其中 dp[i] 表示构造 target 的前 i 个字符的最小成本。
初始化 dp[0] = 0,表示构造空字符串的成本为 0,其他位置初始化为无穷大,表示尚未计算到该位置。
- 遍历目标字符串:
对于目标字符串 target 的每一个位置 i,如果 dp[i] 是无穷大,表示 target 的前 i 个字符无法被构造出来,因此不能从位置 i 继续拼接,直接跳过。
否则,使用 Trie 树的 search 方法,从当前位置 i 开始查找所有可能的前缀及其成本。
对于每一个找到的前缀,更新 dp 数组:dp[i + length] = min(dp[i + length], dp[i] + cost),表示从当前位置 i 开始构造到 i + length 的最小成本。
代码
# 3.8.19 import
import random
from collections import Counter, defaultdict, deque
from datetime import datetime, timedelta
from functools import lru_cache
from heapq import heapify, heappop, heappush, nlargest, nsmallest
from itertools import combinations, compress, permutations, starmap, tee
from math import ceil, fabs, floor, gcd, log, sqrt
from string import ascii_lowercase, ascii_uppercase
from sys import exit, setrecursionlimit, stdin
from typing import Any, Dict, List, Tuple, TypeVar, Union
# Shared template constants
TYPE = TypeVar('TYPE')  # generic element type used in annotations
N = 200_010  # default size for Arr.array / Arr.graph and row count for Arr.array2d
M = 20  # default column count for Arr.array2d
INF = 2_000_000_000  # sentinel cost meaning "not reachable / not set"
OFFSET = 100
# Raise the interpreter's recursion limit to INF
setrecursionlimit(INF)
class Arr:
    """Factories for pre-sized lists plus a helper that shifts data to 1-based indexing."""

    @staticmethod
    def array(x=0, size=N):
        """Return a list holding ``size`` references to ``x``."""
        return [x] * size

    @staticmethod
    def array2d(x=0, rows=N, cols=M):
        """Return ``rows`` separate row lists, each built by ``Arr.array(x, cols)``."""
        return [Arr.array(x, cols) for _ in range(rows)]

    @staticmethod
    def graph(size=N):
        """Return ``size`` distinct empty lists."""
        return [[] for _ in range(size)]

    @staticmethod
    def to_1_indexed(data: Union[List, str, List[List]]):
        """Prepend a zero so the original elements start at index 1.

        Returns:
            * 2-D list (non-empty, every item is a list):
              ``(new_data, row_count, col_count)`` where a zero row is added on
              top and a zero is added in front of every row. The width of the
              zero row is taken from the first input row.
            * any other list (including the empty list): ``(new_data, length)``
              where ``new_data`` is ``[0] + data``.
            * string: ``(new_data, length)`` where ``new_data`` is ``'0' + data``.

        Raises:
            TypeError: if ``data`` is neither a list nor a string.
        """
        if isinstance(data, list):
            # all() is True for an empty iterable, so an empty list must be
            # excluded explicitly; otherwise data[0] below raises IndexError.
            if data and all(isinstance(item, list) for item in data):
                new_data = [[0] * (len(data[0]) + 1)] + [[0] + row for row in data]
                return new_data, len(new_data) - 1, len(new_data[0]) - 1
            new_data = [0] + data
            return new_data, len(new_data) - 1
        if isinstance(data, str):
            new_data = '0' + data
            return new_data, len(new_data) - 1
        raise TypeError("Input must be a list, a 2D list, or a string")
class Str:
    """Small string helpers: letter/index conversion and prefix/suffix stripping."""

    @staticmethod
    def letter_to_num(x):
        """Map a letter to its 0-based alphabet index, ignoring case (A/a -> 0)."""
        return ord(x.upper()) - 65

    @staticmethod
    def num_to_letter(x):
        """Map a 0-based index to the matching uppercase letter (0 -> A)."""
        return ascii_uppercase[x]

    @staticmethod
    def removeprefix(s, prefix):
        """Return ``s`` without a leading ``prefix``; ``s`` unchanged if absent."""
        return s[len(prefix):] if s.startswith(prefix) else s

    @staticmethod
    def removesuffix(s, suffix):
        """Return ``s`` without a trailing ``suffix``; ``s`` unchanged if absent.

        An empty ``suffix`` leaves ``s`` untouched.
        """
        # Every string ends with "", and s[:-0] is s[:0] == "", so the empty
        # suffix has to be excluded before slicing.
        if suffix and s.endswith(suffix):
            return s[:-len(suffix)]
        return s
class Math:
    """Two-argument max/min implemented with a single comparison."""

    @staticmethod
    def max(a, b):
        """Return ``a`` when ``a > b``; otherwise return ``b`` (ties give ``b``)."""
        if a > b:
            return a
        return b

    @staticmethod
    def min(a, b):
        """Return ``a`` when ``a < b``; otherwise return ``b`` (ties give ``b``)."""
        if a < b:
            return a
        return b
class IO:
    """Line-oriented readers on top of ``sys.stdin``."""

    @staticmethod
    def input():
        """Read one line from stdin with trailing CR/LF characters removed."""
        line = stdin.readline()
        return line.rstrip("\r\n")

    @staticmethod
    def read():
        """Read one line and return a lazy ``map`` of its whitespace-separated ints."""
        tokens = IO.input().split()
        return map(int, tokens)

    @staticmethod
    def read_list():
        """Read one line and return its whitespace-separated ints as a list."""
        return list(IO.read())
class Std:
    """Namespace bundling search helpers and reusable data structures."""

    @staticmethod
    def find(container: Union[List[TYPE], str], value: TYPE):
        """Give the index of ``value`` in a list or string, or -1 when absent.

        Lists are searched with ``list.index`` and strings with ``str.find``.
        Any other container type yields ``None``.
        """
        if isinstance(container, str):
            return container.find(value)
        if not isinstance(container, list):
            return None
        try:
            position = container.index(value)
        except ValueError:
            # list.index raises on a miss; report it as -1 like str.find does.
            position = -1
        return position

    @staticmethod
    def pairwise(iterable):
        """Return an iterator over overlapping neighbours (s0, s1), (s1, s2), ..."""
        leading, trailing = tee(iterable, 2)
        next(trailing, None)  # advance the second copy by one element
        return zip(leading, trailing)

    @staticmethod
    def bisect_left(a, x, key=lambda y: y):
        """Return the first index whose keyed element is not less than ``x``."""
        lo = 0
        hi = len(a)
        while lo < hi:
            middle = (lo + hi) // 2
            if key(a[middle]) < x:
                lo = middle + 1
            else:
                hi = middle
        return lo

    @staticmethod
    def bisect_right(a, x, key=lambda y: y):
        """Return the first index whose keyed element is greater than ``x``."""
        lo = 0
        hi = len(a)
        while lo < hi:
            middle = (lo + hi) // 2
            if key(a[middle]) <= x:
                lo = middle + 1
            else:
                hi = middle
        return lo

    class SparseTable:
        """Table of ``func`` folded over windows whose lengths are powers of two."""

        def __init__(self, data: list, func=lambda x, y: x | y):
            """Copy ``data`` as level 0 and build each doubled-width level from the last."""
            base = list(data)
            total = len(base)
            self.func = func
            self.st = [base]
            half = 1
            while half * 2 <= total:
                prev = self.st[-1]
                level = []
                # Entry `start` of the new level covers 2 * half elements:
                # two adjacent windows of the previous level.
                for start in range(total - 2 * half + 1):
                    level.append(func(prev[start], prev[start + half]))
                self.st.append(level)
                half *= 2

        def query(self, begin: int, end: int):
            """Combine the inclusive range [begin, end] via two stored windows.

            The two windows may overlap, so elements in the overlap are fed to
            ``func`` twice.
            """
            level = (end - begin + 1).bit_length() - 1
            row = self.st[level]
            return self.func(row[begin], row[end - (1 << level) + 1])

    class TrieNode:
        """Trie node keyed by character, storing the cheapest cost of a word ending here."""

        def __init__(self):
            """Create a node with no children and cost ``INF`` (no word ends here)."""
            self.children = {}
            self.cost = INF

        def add(self, word, cost):
            """Insert ``word``, keeping the smaller cost if it was already present."""
            current = self
            for ch in word:
                nxt = current.children.get(ch)
                if nxt is None:
                    nxt = Std.TrieNode()
                    current.children[ch] = nxt
                current = nxt
            if cost < current.cost:
                current.cost = cost

        def search(self, word):
            """Return ``[length, cost]`` for every stored word that is a prefix of ``word``.

            Entries appear in increasing order of length.
            """
            matches = []
            current = self
            for length, ch in enumerate(word, start=1):
                current = current.children.get(ch)
                if current is None:
                    break
                if current.cost != INF:
                    matches.append([length, current.cost])
            return matches

    class StringHash:
        """Polynomial prefix hashing of a string with a randomly chosen base."""

        def __init__(self, s: str, mod: int = 1_070_777_777):
            """Store ``s`` and ``mod``, pick a random base, and precompute the tables."""
            self.s = s
            self.n = len(s)
            self.mod = mod
            self.base = random.randint(8 * 10 ** 8, 9 * 10 ** 8)
            # pow_base[i] = base^i mod mod
            self.pow_base = [1] + Arr.array(0, self.n)
            # pre_hash[i] = hash of s[:i]
            self.pre_hash = Arr.array(0, self.n + 1)
            self._compute_hash()

        def _compute_hash(self):
            """Fill ``pow_base`` and ``pre_hash`` in one pass over the string."""
            base, mod = self.base, self.mod
            powers, prefix = self.pow_base, self.pre_hash
            for idx, ch in enumerate(self.s, start=1):
                powers[idx] = powers[idx - 1] * base % mod
                prefix[idx] = (prefix[idx - 1] * base + ord(ch)) % mod

        def get_sub_hash(self, l: int, r: int) -> int:
            """Return the hash of ``s[l:r+1]`` (both ends inclusive)."""
            mod = self.mod
            shifted_head = self.pre_hash[l] * self.pow_base[r - l + 1] % mod
            return (self.pre_hash[r + 1] - shifted_head + mod) % mod

        def get_full_hash(self) -> int:
            """Return the hash of the whole string."""
            return self.pre_hash[self.n]

        def compute_hash(self, word: str) -> int:
            """Hash an arbitrary ``word`` with this object's base and modulus."""
            value = 0
            for ch in word:
                value = (value * self.base + ord(ch)) % self.mod
            return value
# ————————————————————— Division line ——————————————————————
class Solution:
    """Trie + DP solution.

    NOTE: a second class named ``Solution`` is defined later in this file and
    rebinds the name when the module is loaded.
    """

    def minimumCost(self, target: str, words: List[str], costs: List[int]) -> int:
        """Return the minimum total cost to build ``target`` from ``words``.

        ``dp[i]`` is the cheapest cost of building ``target[:i]``. From every
        reachable ``i`` the trie is walked along ``target[i:]`` and each word
        ending on the path relaxes the matching later state.

        Returns:
            The minimum cost, or -1 when ``target`` cannot be built.
        """
        # Build the trie; duplicate words keep their cheapest cost.
        trie = Std.TrieNode()
        for word, cost in zip(words, costs):
            trie.add(word, cost)

        n = len(target)
        dp = Arr.array(INF, n + 1)
        dp[0] = 0

        for i in range(n):
            base_cost = dp[i]
            if base_cost == INF:
                # target[:i] cannot be built, so nothing can start here.
                continue
            # Walk the trie over target in place instead of slicing target[i:],
            # which would copy the whole suffix for every start index.
            node = trie
            for j in range(i, n):
                node = node.children.get(target[j])
                if node is None:
                    break
                if node.cost != INF:
                    candidate = base_cost + node.cost
                    if candidate < dp[j + 1]:
                        dp[j + 1] = candidate

        return dp[n] if dp[n] != INF else -1
class Solution:
    """String-hash + DP solution."""

    def minimumCost(self, target: str, words: List[str], costs: List[int]) -> int:
        """Return the minimum total cost to build ``target`` from ``words``.

        ``dp[i]`` is the cheapest cost of building ``target[:i]``. Words are
        matched against substrings of ``target`` by comparing hash values, so
        a match is only as reliable as the hash.

        Returns:
            The minimum cost, or -1 when ``target`` cannot be built.
        """
        n = len(target)
        target_hash = Std.StringHash(target)

        # word length -> {hash of word: cheapest cost}. Grouping by length
        # means a hash is only compared with words of the same length.
        cost_by_len = {}
        for w, c in zip(words, costs):
            size = len(w)
            if size > n:
                # A word longer than target can never be placed.
                continue
            table = cost_by_len.setdefault(size, {})
            h = target_hash.compute_hash(w)
            table[h] = min(table.get(h, INF), c)

        # Distinct word lengths in increasing order, each with its table.
        tables = [(size, cost_by_len[size]) for size in sorted(cost_by_len)]

        dp = Arr.array(INF, n + 1)
        dp[0] = 0
        for i in range(n):
            if dp[i] == INF:
                continue
            for sz, table in tables:
                if i + sz > n:
                    # Lengths are sorted, so all remaining ones overflow too.
                    break
                # Hash of the substring target[i:i+sz]
                sub_hash = target_hash.get_sub_hash(i, i + sz - 1)
                # dict.get does not insert on a miss, so the table stays the
                # size of the word set no matter how many probes fail.
                cost = table.get(sub_hash)
                if cost is not None:
                    dp[i + sz] = min(dp[i + sz], dp[i] + cost)
        return -1 if dp[n] == INF else dp[n]