wip: add getExtensionField

This commit is contained in:
Philipp Kühn
2021-04-15 21:14:33 +02:00
parent d194b90a61
commit 07bc40ce75
11 changed files with 213 additions and 274 deletions

View File

@@ -4,8 +4,8 @@ import { inputRules as inputRulesPlugin } from 'prosemirror-inputrules'
import { EditorView, Decoration } from 'prosemirror-view'
import { Plugin } from 'prosemirror-state'
import { Editor } from './Editor'
import { Extensions, NodeViewRenderer, RawCommands } from './types'
import createExtensionContext from './helpers/createExtensionContext'
import { Extensions, RawCommands, AnyConfig } from './types'
import getExtensionField from './helpers/getExtensionField'
import getSchema from './helpers/getSchema'
import getSchemaTypeByName from './helpers/getSchemaTypeByName'
import getNodeType from './helpers/getNodeType'
@@ -13,6 +13,7 @@ import splitExtensions from './helpers/splitExtensions'
import getAttributesFromExtensions from './helpers/getAttributesFromExtensions'
import getRenderedAttributes from './helpers/getRenderedAttributes'
import callOrReturn from './utilities/callOrReturn'
import { NodeConfig } from '.'
export default class ExtensionManager {
@@ -30,51 +31,51 @@ export default class ExtensionManager {
this.schema = getSchema(this.extensions)
this.extensions.forEach(extension => {
const context = createExtensionContext(extension, {
const context = {
options: extension.options,
editor: this.editor,
type: getSchemaTypeByName(extension.config.name, this.schema),
})
}
if (extension.type === 'mark') {
const keepOnSplit = callOrReturn(extension.config.keepOnSplit, context) ?? true
const keepOnSplit = callOrReturn(getExtensionField(extension, 'keepOnSplit', context)) ?? true
if (keepOnSplit) {
this.splittableMarks.push(extension.config.name)
}
}
if (typeof extension.config.onBeforeCreate === 'function') {
this.editor.on('beforeCreate', extension.config.onBeforeCreate.bind(context))
}
// if (typeof extension.config.onBeforeCreate === 'function') {
// this.editor.on('beforeCreate', extension.config.onBeforeCreate.bind(context))
// }
if (typeof extension.config.onCreate === 'function') {
this.editor.on('create', extension.config.onCreate.bind(context))
}
// if (typeof extension.config.onCreate === 'function') {
// this.editor.on('create', extension.config.onCreate.bind(context))
// }
if (typeof extension.config.onUpdate === 'function') {
this.editor.on('update', extension.config.onUpdate.bind(context))
}
// if (typeof extension.config.onUpdate === 'function') {
// this.editor.on('update', extension.config.onUpdate.bind(context))
// }
if (typeof extension.config.onSelectionUpdate === 'function') {
this.editor.on('selectionUpdate', extension.config.onSelectionUpdate.bind(context))
}
// if (typeof extension.config.onSelectionUpdate === 'function') {
// this.editor.on('selectionUpdate', extension.config.onSelectionUpdate.bind(context))
// }
if (typeof extension.config.onTransaction === 'function') {
this.editor.on('transaction', extension.config.onTransaction.bind(context))
}
// if (typeof extension.config.onTransaction === 'function') {
// this.editor.on('transaction', extension.config.onTransaction.bind(context))
// }
if (typeof extension.config.onFocus === 'function') {
this.editor.on('focus', extension.config.onFocus.bind(context))
}
// if (typeof extension.config.onFocus === 'function') {
// this.editor.on('focus', extension.config.onFocus.bind(context))
// }
if (typeof extension.config.onBlur === 'function') {
this.editor.on('blur', extension.config.onBlur.bind(context))
}
// if (typeof extension.config.onBlur === 'function') {
// this.editor.on('blur', extension.config.onBlur.bind(context))
// }
if (typeof extension.config.onDestroy === 'function') {
this.editor.on('destroy', extension.config.onDestroy.bind(context))
}
// if (typeof extension.config.onDestroy === 'function') {
// this.editor.on('destroy', extension.config.onDestroy.bind(context))
// }
})
}
@@ -96,11 +97,11 @@ export default class ExtensionManager {
get commands(): RawCommands {
return this.extensions.reduce((commands, extension) => {
const context = createExtensionContext(extension, {
const context = {
options: extension.options,
editor: this.editor,
type: getSchemaTypeByName(extension.config.name, this.schema),
})
}
if (!extension.config.addCommands) {
return commands
@@ -108,7 +109,7 @@ export default class ExtensionManager {
return {
...commands,
...extension.config.addCommands.bind(context)(),
...getExtensionField(extension, 'addCommands', context)(),
}
}, {} as RawCommands)
}
@@ -117,22 +118,34 @@ export default class ExtensionManager {
return [...this.extensions]
.reverse()
.map(extension => {
const context = createExtensionContext(extension, {
const context = {
options: extension.options,
editor: this.editor,
type: getSchemaTypeByName(extension.config.name, this.schema),
})
}
const plugins: Plugin[] = []
if (extension.config.addKeyboardShortcuts) {
const keyMapPlugin = keymap(extension.config.addKeyboardShortcuts.bind(context)())
const addKeyboardShortcuts = getExtensionField<AnyConfig['addKeyboardShortcuts']>(
extension,
'addKeyboardShortcuts',
context,
)
if (addKeyboardShortcuts) {
const keyMapPlugin = keymap(addKeyboardShortcuts())
plugins.push(keyMapPlugin)
}
if (this.editor.options.enableInputRules && extension.config.addInputRules) {
const inputRules = extension.config.addInputRules.bind(context)()
const addInputRules = getExtensionField<AnyConfig['addInputRules']>(
extension,
'addInputRules',
context,
)
if (this.editor.options.enableInputRules && addInputRules) {
const inputRules = addInputRules()
const inputRulePlugins = inputRules.length
? [inputRulesPlugin({ rules: inputRules })]
: []
@@ -140,46 +153,30 @@ export default class ExtensionManager {
plugins.push(...inputRulePlugins)
}
if (this.editor.options.enablePasteRules && extension.config.addPasteRules) {
const pasteRulePlugins = extension.config.addPasteRules.bind(context)()
const addPasteRules = getExtensionField<AnyConfig['addPasteRules']>(
extension,
'addPasteRules',
context,
)
if (this.editor.options.enablePasteRules && addPasteRules) {
const pasteRulePlugins = addPasteRules()
plugins.push(...pasteRulePlugins)
}
// console.log('has pm', extension.config.addProseMirrorPlugins, extension)
const addProseMirrorPlugins = getExtensionField<AnyConfig['addProseMirrorPlugins']>(
extension,
'addProseMirrorPlugins',
context,
)
const getItem = (rootext: any, ext: any, field: string): any => {
const realctx = createExtensionContext(ext, {
options: rootext.options,
// options: getItem(ext, 'defaultOptions'),
editor: this.editor,
type: getSchemaTypeByName(ext.config.name, this.schema),
})
if (addProseMirrorPlugins) {
const proseMirrorPlugins = addProseMirrorPlugins()
if (ext.config[field]) {
if (typeof ext.config[field] === 'function') {
return ext.config[field].bind(realctx)()
}
return ext.config[field]
}
if (ext.parent) {
return getItem(rootext, ext.parent, field)
}
return undefined
plugins.push(...proseMirrorPlugins)
}
// console.log('get PM', getItem(extension, 'addProseMirrorPlugins', context))
const realPMP = getItem(extension, extension, 'addProseMirrorPlugins')
// if (extension.config.addProseMirrorPlugins) {
// const proseMirrorPlugins = extension.config.addProseMirrorPlugins.bind(context)()
// plugins.push(...proseMirrorPlugins)
// }
return plugins
})
.flat()
@@ -194,15 +191,24 @@ export default class ExtensionManager {
const { nodeExtensions } = splitExtensions(this.extensions)
return Object.fromEntries(nodeExtensions
.filter(extension => !!extension.config.addNodeView)
.filter(extension => !!getExtensionField(extension, 'addNodeView'))
.map(extension => {
const extensionAttributes = this.attributes.filter(attribute => attribute.type === extension.config.name)
const context = createExtensionContext(extension, {
const name = getExtensionField<NodeConfig['name']>(extension, 'name')
const extensionAttributes = this.attributes.filter(attribute => attribute.type === name)
const context = {
options: extension.options,
editor,
type: getNodeType(extension.config.name, this.schema),
})
const renderer = extension.config.addNodeView?.call(context) as NodeViewRenderer
}
const addNodeView = getExtensionField<NodeConfig['addNodeView']>(
extension,
'addNodeView',
context,
)
if (!addNodeView) {
return []
}
const nodeview = (
node: ProsemirrorNode,
@@ -212,7 +218,7 @@ export default class ExtensionManager {
) => {
const HTMLAttributes = getRenderedAttributes(node, extensionAttributes)
return renderer({
return addNodeView()({
editor,
node,
getPos,
@@ -231,15 +237,21 @@ export default class ExtensionManager {
const { nodeExtensions } = splitExtensions(this.extensions)
return Object.fromEntries(nodeExtensions
.filter(extension => !!extension.config.renderText)
.filter(extension => !!getExtensionField(extension, 'renderText'))
.map(extension => {
const context = createExtensionContext(extension, {
const context = {
options: extension.options,
editor,
type: getNodeType(extension.config.name, this.schema),
})
}
const textSerializer = (props: { node: ProsemirrorNode }) => extension.config.renderText?.call(context, props)
const renderText = getExtensionField<NodeConfig['renderText']>(extension, 'renderText', context)
if (!renderText) {
return []
}
const textSerializer = (props: { node: ProsemirrorNode }) => renderText(props)
return [extension.config.name, textSerializer]
}))

View File

@@ -1,101 +0,0 @@
import { AnyExtension, AnyObject } from '../types'
// export default function createExtensionContext<T>(
// extension: AnyExtension,
// data: T,
// ): T & { parentConfig: AnyObject } {
// const context: any = {
// ...data,
// // get parentConfig() {
// // return Object.fromEntries(Object.entries(extension.parentConfig).map(([key, value]) => {
// // if (typeof value !== 'function') {
// // return [key, value]
// // }
// // console.log('call', key)
// // return [key, value.bind(context)]
// // }))
// // },
// parentConfig: Object.fromEntries(Object.entries(extension.parent.config).map(([key, value]) => {
// if (typeof value !== 'function') {
// return [key, value]
// }
// // console.log('call', key)
// return [key, value.bind(data)]
// })),
// // get parentConfig() {
// // console.log('parent', extension.parent)
// // console.log('parent parent', extension.parent?.parent)
// // return Object.fromEntries(Object.entries(extension.parent.config).map(([key, value]) => {
// // if (typeof value !== 'function') {
// // return [key, value]
// // }
// // // console.log('call', key)
// // return [key, value.bind(context)]
// // }))
// // },
// // parentConfig: null,
// }
// return context
// }
// export default function createExtensionContext<T>(
// extension: AnyExtension,
// data: T,
// // @ts-ignore
// ): T & { parentConfig: AnyObject } {
// const context: any = data
// if (!extension.parent) {
// context.parentConfig = {}
// return context
// }
// // const bla = {
// // ...(extension.parent.parent ? extension.parent.parent.config : {}),
// // ...extension.parent.config,
// // }
// context.parentConfig = Object.fromEntries(Object.entries(extension.parent.config).map(([key, value]) => {
// if (typeof value !== 'function') {
// return [key, value]
// }
// // console.log('call', key)
// return [key, value.bind(createExtensionContext(extension.parent, data))]
// }))
// return context
// }
export default function createExtensionContext<T>(
extension: AnyExtension,
data: T,
): T & { parentConfig: AnyObject } {
const context: any = {
...data,
get parentConfig() {
return Object.fromEntries(Object.entries(extension.parent.config).map(([key, value]) => {
if (typeof value !== 'function') {
return [key, value]
}
return [key, value.bind(context)]
}))
},
}
return context
}

View File

@@ -1,12 +1,14 @@
import createExtensionContext from './createExtensionContext'
import splitExtensions from './splitExtensions'
import getExtensionField from './getExtensionField'
import {
Extensions,
GlobalAttributes,
Attributes,
Attribute,
ExtensionAttribute,
AnyConfig,
} from '../types'
import { NodeConfig, MarkConfig } from '..'
/**
* Get a list of all extension attributes defined in `addAttribute` and `addGlobalAttribute`.
@@ -25,15 +27,22 @@ export default function getAttributesFromExtensions(extensions: Extensions): Ext
}
extensions.forEach(extension => {
const context = createExtensionContext(extension, {
const context = {
options: extension.options,
})
}
if (!extension.config.addGlobalAttributes) {
const addGlobalAttributes = getExtensionField<AnyConfig['addGlobalAttributes']>(
extension,
'addGlobalAttributes',
context,
)
if (!addGlobalAttributes) {
return
}
const globalAttributes = extension.config.addGlobalAttributes.bind(context)() as GlobalAttributes
// TODO: remove `as GlobalAttributes`
const globalAttributes = addGlobalAttributes() as GlobalAttributes
globalAttributes.forEach(globalAttribute => {
globalAttribute.types.forEach(type => {
@@ -54,21 +63,28 @@ export default function getAttributesFromExtensions(extensions: Extensions): Ext
})
nodeAndMarkExtensions.forEach(extension => {
const context = createExtensionContext(extension, {
const context = {
options: extension.options,
})
}
if (!extension.config.addAttributes) {
const addAttributes = getExtensionField<NodeConfig['addAttributes'] | MarkConfig['addAttributes']>(
extension,
'addAttributes',
context,
)
if (!addAttributes) {
return
}
const attributes = extension.config.addAttributes.bind(context)() as Attributes
// TODO: remove `as Attributes`
const attributes = addAttributes() as Attributes
Object
.entries(attributes)
.forEach(([name, attribute]) => {
extensionAttributes.push({
type: extension.config.name,
type: getExtensionField(extension, 'name'),
name,
attribute: {
...defaultAttribute,

View File

@@ -0,0 +1,25 @@
import { AnyExtension, AnyObject, RemoveThis } from '../types'
export default function getExtensionField<T = any>(
extension: AnyExtension,
field: string,
context: AnyObject = {},
): RemoveThis<T> {
if (extension.config[field] === undefined && extension.parent) {
return getExtensionField(extension.parent, field, context)
}
if (typeof extension.config[field] === 'function') {
const value = extension.config[field].bind({
...context,
parent: extension.parent
? getExtensionField(extension.parent, field, context)
: null,
})
return value
}
return extension.config[field]
}

View File

@@ -1,13 +1,13 @@
import { NodeSpec, MarkSpec, Schema } from 'prosemirror-model'
import { Extensions } from '../types'
import { ExtensionConfig, NodeConfig, MarkConfig } from '..'
import createExtensionContext from './createExtensionContext'
import splitExtensions from './splitExtensions'
import getAttributesFromExtensions from './getAttributesFromExtensions'
import getRenderedAttributes from './getRenderedAttributes'
import isEmptyObject from '../utilities/isEmptyObject'
import injectExtensionAttributesToParseRule from './injectExtensionAttributesToParseRule'
import callOrReturn from '../utilities/callOrReturn'
import getExtensionField from './getExtensionField'
function cleanUpSchemaItem<T>(data: T) {
return Object.fromEntries(Object.entries(data).filter(([key, value]) => {
@@ -46,9 +46,9 @@ export default function getSchema(extensions: Extensions): Schema {
const nodes = Object.fromEntries(nodeExtensions.map(extension => {
const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.config.name)
const context = createExtensionContext(extension, {
const context = {
options: extension.options,
})
}
const extraNodeFields = nodeSchemaExtenders.reduce((fields, nodeSchemaExtender) => {
const extraFields = callOrReturn(nodeSchemaExtender, context, extension)
@@ -61,29 +61,32 @@ export default function getSchema(extensions: Extensions): Schema {
const schema: NodeSpec = cleanUpSchemaItem({
...extraNodeFields,
content: callOrReturn(extension.config.content, context),
marks: callOrReturn(extension.config.marks, context),
group: callOrReturn(extension.config.group, context),
inline: callOrReturn(extension.config.inline, context),
atom: callOrReturn(extension.config.atom, context),
selectable: callOrReturn(extension.config.selectable, context),
draggable: callOrReturn(extension.config.draggable, context),
code: callOrReturn(extension.config.code, context),
defining: callOrReturn(extension.config.defining, context),
isolating: callOrReturn(extension.config.isolating, context),
content: callOrReturn(getExtensionField<NodeConfig['content']>(extension, 'content', context)),
marks: callOrReturn(getExtensionField<NodeConfig['marks']>(extension, 'marks', context)),
group: callOrReturn(getExtensionField<NodeConfig['group']>(extension, 'group', context)),
inline: callOrReturn(getExtensionField<NodeConfig['inline']>(extension, 'inline', context)),
atom: callOrReturn(getExtensionField<NodeConfig['atom']>(extension, 'atom', context)),
selectable: callOrReturn(getExtensionField<NodeConfig['selectable']>(extension, 'selectable', context)),
draggable: callOrReturn(getExtensionField<NodeConfig['draggable']>(extension, 'draggable', context)),
code: callOrReturn(getExtensionField<NodeConfig['code']>(extension, 'code', context)),
defining: callOrReturn(getExtensionField<NodeConfig['defining']>(extension, 'defining', context)),
isolating: callOrReturn(getExtensionField<NodeConfig['isolating']>(extension, 'isolating', context)),
attrs: Object.fromEntries(extensionAttributes.map(extensionAttribute => {
return [extensionAttribute.name, { default: extensionAttribute?.attribute?.default }]
})),
})
if (extension.config.parseHTML) {
schema.parseDOM = extension.config.parseHTML
.bind(context)()
?.map(parseRule => injectExtensionAttributesToParseRule(parseRule, extensionAttributes))
const parseHTML = callOrReturn(getExtensionField<NodeConfig['parseHTML']>(extension, 'parseHTML', context))
if (parseHTML) {
schema.parseDOM = parseHTML
.map(parseRule => injectExtensionAttributesToParseRule(parseRule, extensionAttributes))
}
if (extension.config.renderHTML) {
schema.toDOM = node => (extension.config.renderHTML as Function)?.bind(context)({
const renderHTML = getExtensionField<NodeConfig['renderHTML']>(extension, 'renderHTML', context)
if (renderHTML) {
schema.toDOM = node => renderHTML({
node,
HTMLAttributes: getRenderedAttributes(node, extensionAttributes),
})
@@ -94,9 +97,9 @@ export default function getSchema(extensions: Extensions): Schema {
const marks = Object.fromEntries(markExtensions.map(extension => {
const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.config.name)
const context = createExtensionContext(extension, {
const context = {
options: extension.options,
})
}
const extraMarkFields = markSchemaExtenders.reduce((fields, markSchemaExtender) => {
const extraFields = callOrReturn(markSchemaExtender, context, extension)
@@ -109,23 +112,26 @@ export default function getSchema(extensions: Extensions): Schema {
const schema: MarkSpec = cleanUpSchemaItem({
...extraMarkFields,
inclusive: callOrReturn(extension.config.inclusive, context),
excludes: callOrReturn(extension.config.excludes, context),
group: callOrReturn(extension.config.group, context),
spanning: callOrReturn(extension.config.spanning, context),
inclusive: callOrReturn(getExtensionField<NodeConfig['inclusive']>(extension, 'inclusive', context)),
excludes: callOrReturn(getExtensionField<NodeConfig['excludes']>(extension, 'excludes', context)),
group: callOrReturn(getExtensionField<NodeConfig['group']>(extension, 'group', context)),
spanning: callOrReturn(getExtensionField<NodeConfig['spanning']>(extension, 'spanning', context)),
attrs: Object.fromEntries(extensionAttributes.map(extensionAttribute => {
return [extensionAttribute.name, { default: extensionAttribute?.attribute?.default }]
})),
})
if (extension.config.parseHTML) {
schema.parseDOM = extension.config.parseHTML
.bind(context)()
?.map(parseRule => injectExtensionAttributesToParseRule(parseRule, extensionAttributes))
const parseHTML = callOrReturn(getExtensionField<MarkConfig['parseHTML']>(extension, 'parseHTML', context))
if (parseHTML) {
schema.parseDOM = parseHTML
.map(parseRule => injectExtensionAttributesToParseRule(parseRule, extensionAttributes))
}
if (extension.config.renderHTML) {
schema.toDOM = mark => (extension.config.renderHTML as Function)?.bind(context)({
const renderHTML = getExtensionField<MarkConfig['renderHTML']>(extension, 'renderHTML', context)
if (renderHTML) {
schema.toDOM = mark => renderHTML({
mark,
HTMLAttributes: getRenderedAttributes(mark, extensionAttributes),
})

View File

@@ -12,7 +12,7 @@ export { default as markPasteRule } from './pasteRules/markPasteRule'
export { default as callOrReturn } from './utilities/callOrReturn'
export { default as mergeAttributes } from './utilities/mergeAttributes'
export { default as createExtensionContext } from './helpers/createExtensionContext'
export { default as getExtensionField } from './helpers/getExtensionField'
export { default as findChildren } from './helpers/findChildren'
export { default as findParentNode } from './helpers/findParentNode'
export { default as findParentNodeClosestToPos } from './helpers/findParentNodeClosestToPos'

View File

@@ -14,17 +14,31 @@ import { Extension } from './Extension'
import { Node } from './Node'
import { Mark } from './Mark'
import { Editor } from './Editor'
import { Commands } from '.'
import {
Commands,
ExtensionConfig,
NodeConfig,
MarkConfig,
} from '.'
export type AnyConfig = ExtensionConfig | NodeConfig | MarkConfig
export type AnyExtension = Extension | Node | Mark
export type Extensions = AnyExtension[]
export type ParentConfig<T> = Partial<{
[P in keyof T]: Required<T>[P] extends () => any
[P in keyof T]: Required<T>[P] extends (...args: any) => any
? (...args: Parameters<Required<T>[P]>) => ReturnType<Required<T>[P]>
: T[P]
}>
export type RemoveThis<T> = T extends (...args: any) => any
? (...args: Parameters<T>) => ReturnType<T>
: T
export type MaybeReturnType<T> = T extends (...args: any) => any
? ReturnType<T>
: T
export interface EditorOptions {
element: Element,
content: Content,

View File

@@ -1,3 +1,5 @@
import { MaybeReturnType } from '../types'
/**
* Optionally calls `value` as a function.
* Otherwise it is returned directly.
@@ -5,7 +7,7 @@
* @param context Optional context to bind to function.
* @param props Optional props to pass to function.
*/
export default function callOrReturn(value: any, context: any = undefined, ...props: any[]): any {
export default function callOrReturn<T>(value: T, context: any = undefined, ...props: any[]): MaybeReturnType<T> {
if (typeof value === 'function') {
if (context) {
return value.bind(context)(...props)
@@ -14,5 +16,5 @@ export default function callOrReturn(value: any, context: any = undefined, ...pr
return value(...props)
}
return value
return value as MaybeReturnType<T>
}